add files
This commit is contained in:
unsloth_compiled_cache/UnslothAlignPropTrainer.py (637 lines, new file)
@@ -0,0 +1,637 @@
"""
2025.6.1
2025.6.2
4.52.4
0.18.2
__UNSLOTH_VERSIONING__
"""
from torch import Tensor
import torch
import torch.nn as nn
from torch.nn import functional as F
from trl.trainer.alignprop_trainer import (Accelerator, AlignPropConfig, AlignPropTrainer, Any, Callable, DDPOStableDiffusionPipeline, Optional, ProjectConfiguration, PyTorchModelHubMixin, Union, defaultdict, generate_model_card, get_comet_experiment_url, is_wandb_available, logger, os, set_seed, textwrap, torch, warn)


import os
from typing import *
from dataclasses import dataclass, field
from packaging.version import Version
import torch
import numpy as np
from contextlib import nullcontext
from torch.nn import functional as F
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling

torch_compile_options = {
    "epilogue_fusion"   : True,
    "max_autotune"      : False,
    "shape_padding"     : True,
    "trace.enabled"     : False,
    "triton.cudagraphs" : False,
}

@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
def selective_log_softmax(logits, index):
    logits = logits.to(torch.float32)
    selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
    # loop to reduce peak mem consumption
    # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
    logsumexp_values = torch.logsumexp(logits, dim = -1)
    per_token_logps = selected_logits - logsumexp_values  # log_softmax(x_i) = x_i - logsumexp(x)
    return per_token_logps
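# Illustrative sanity check (not part of the generated cache): selective_log_softmax should match
# gathering the selected entries from a full log_softmax. A minimal sketch, assuming small dummy
# tensors:
#
#     logits = torch.randn(2, 5, 10)
#     index  = torch.randint(0, 10, (2, 5))
#     reference = F.log_softmax(logits.float(), dim = -1).gather(-1, index.unsqueeze(-1)).squeeze(-1)
#     assert torch.allclose(selective_log_softmax(logits, index), reference, atol = 1e-5)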
@dataclass
|
||||
class UnslothAlignPropConfig(AlignPropConfig):
|
||||
"""
|
||||
|
||||
Configuration class for the [`AlignPropTrainer`].
|
||||
|
||||
Using [`~transformers.HfArgumentParser`] we can turn this class into
|
||||
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
|
||||
command line.
|
||||
|
||||
Parameters:
|
||||
exp_name (`str`, *optional*, defaults to `os.path.basename(sys.argv[0])[: -len(".py")]`):
|
||||
Name of this experiment (defaults to the file name without the extension).
|
||||
run_name (`str`, *optional*, defaults to `""`):
|
||||
Name of this run.
|
||||
seed (`int`, *optional*, defaults to `0`):
|
||||
Random seed for reproducibility.
|
||||
log_with (`str` or `None`, *optional*, defaults to `None`):
|
||||
Log with either `"wandb"` or `"tensorboard"`. Check
|
||||
[tracking](https://huggingface.co/docs/accelerate/usage_guides/tracking) for more details.
|
||||
log_image_freq (`int`, *optional*, defaults to `1`):
|
||||
Frequency for logging images.
|
||||
tracker_kwargs (`dict[str, Any]`, *optional*, defaults to `{}`):
|
||||
Keyword arguments for the tracker (e.g., `wandb_project`).
|
||||
accelerator_kwargs (`dict[str, Any]`, *optional*, defaults to `{}`):
|
||||
Keyword arguments for the accelerator.
|
||||
project_kwargs (`dict[str, Any]`, *optional*, defaults to `{}`):
|
||||
Keyword arguments for the accelerator project config (e.g., `logging_dir`).
|
||||
tracker_project_name (`str`, *optional*, defaults to `"trl"`):
|
||||
Name of project to use for tracking.
|
||||
logdir (`str`, *optional*, defaults to `"logs"`):
|
||||
Top-level logging directory for checkpoint saving.
|
||||
num_epochs (`int`, *optional*, defaults to `100`):
|
||||
Number of epochs to train.
|
||||
save_freq (`int`, *optional*, defaults to `1`):
|
||||
Number of epochs between saving model checkpoints.
|
||||
num_checkpoint_limit (`int`, *optional*, defaults to `5`):
|
||||
Number of checkpoints to keep before overwriting old ones.
|
||||
mixed_precision (`str`, *optional*, defaults to `"fp16"`):
|
||||
Mixed precision training.
|
||||
allow_tf32 (`bool`, *optional*, defaults to `True`):
|
||||
Allow `tf32` on Ampere GPUs.
|
||||
resume_from (`str`, *optional*, defaults to `""`):
|
||||
Path to resume training from a checkpoint.
|
||||
sample_num_steps (`int`, *optional*, defaults to `50`):
|
||||
Number of sampler inference steps.
|
||||
sample_eta (`float`, *optional*, defaults to `1.0`):
|
||||
Eta parameter for the DDIM sampler.
|
||||
sample_guidance_scale (`float`, *optional*, defaults to `5.0`):
|
||||
Classifier-free guidance weight.
|
||||
train_batch_size (`int`, *optional*, defaults to `1`):
|
||||
Batch size for training.
|
||||
train_use_8bit_adam (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use the 8bit Adam optimizer from `bitsandbytes`.
|
||||
train_learning_rate (`float`, *optional*, defaults to `1e-3`):
|
||||
Learning rate.
|
||||
train_adam_beta1 (`float`, *optional*, defaults to `0.9`):
|
||||
Beta1 for Adam optimizer.
|
||||
train_adam_beta2 (`float`, *optional*, defaults to `0.999`):
|
||||
Beta2 for Adam optimizer.
|
||||
train_adam_weight_decay (`float`, *optional*, defaults to `1e-4`):
|
||||
Weight decay for Adam optimizer.
|
||||
train_adam_epsilon (`float`, *optional*, defaults to `1e-8`):
|
||||
Epsilon value for Adam optimizer.
|
||||
train_gradient_accumulation_steps (`int`, *optional*, defaults to `1`):
|
||||
Number of gradient accumulation steps.
|
||||
train_max_grad_norm (`float`, *optional*, defaults to `1.0`):
|
||||
Maximum gradient norm for gradient clipping.
|
||||
negative_prompts (`str` or `None`, *optional*, defaults to `None`):
|
||||
Comma-separated list of prompts to use as negative examples.
|
||||
truncated_backprop_rand (`bool`, *optional*, defaults to `True`):
|
||||
If `True`, randomized truncation to different diffusion timesteps is used.
|
||||
truncated_backprop_timestep (`int`, *optional*, defaults to `49`):
|
||||
Absolute timestep to which the gradients are backpropagated. Used only if `truncated_backprop_rand=False`.
|
||||
truncated_rand_backprop_minmax (`tuple[int, int]`, *optional*, defaults to `(0, 50)`):
|
||||
Range of diffusion timesteps for randomized truncated backpropagation.
|
||||
push_to_hub (`bool`, *optional*, defaults to `False`):
|
||||
Whether to push the final model to the Hub.
|
||||
|
||||
"""
|
||||
vllm_sampling_params: Optional[Any] = field(
|
||||
default = None,
|
||||
metadata = {'help': 'vLLM SamplingParams'},
|
||||
)
|
||||
unsloth_num_chunks : Optional[int] = field(
|
||||
default = -1,
|
||||
metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
|
||||
)
|
||||
def __init__(
|
||||
self,
|
||||
exp_name = 'test',
|
||||
run_name = '',
|
||||
seed = 3407,
|
||||
log_with = None,
|
||||
log_image_freq = 1,
|
||||
tracker_project_name = 'trl',
|
||||
logdir = 'logs',
|
||||
num_epochs = 100,
|
||||
save_freq = 1,
|
||||
num_checkpoint_limit = 5,
|
||||
mixed_precision = 'fp16',
|
||||
allow_tf32 = True,
|
||||
resume_from = '',
|
||||
sample_num_steps = 50,
|
||||
sample_eta = 1.0,
|
||||
sample_guidance_scale = 5.0,
|
||||
train_batch_size = 1,
|
||||
train_use_8bit_adam = False,
|
||||
train_learning_rate = 5e-05,
|
||||
train_adam_beta1 = 0.9,
|
||||
train_adam_beta2 = 0.999,
|
||||
train_adam_weight_decay = 0.01,
|
||||
train_adam_epsilon = 1e-08,
|
||||
train_gradient_accumulation_steps = 2,
|
||||
train_max_grad_norm = 1.0,
|
||||
negative_prompts = None,
|
||||
truncated_backprop_rand = True,
|
||||
truncated_backprop_timestep = 49,
|
||||
push_to_hub = False,
|
||||
vllm_sampling_params = None,
|
||||
unsloth_num_chunks = -1,
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
super().__init__(
|
||||
exp_name = exp_name,
|
||||
run_name = run_name,
|
||||
seed = seed,
|
||||
log_with = log_with,
|
||||
log_image_freq = log_image_freq,
|
||||
tracker_project_name = tracker_project_name,
|
||||
logdir = logdir,
|
||||
num_epochs = num_epochs,
|
||||
save_freq = save_freq,
|
||||
num_checkpoint_limit = num_checkpoint_limit,
|
||||
mixed_precision = mixed_precision,
|
||||
allow_tf32 = allow_tf32,
|
||||
resume_from = resume_from,
|
||||
sample_num_steps = sample_num_steps,
|
||||
sample_eta = sample_eta,
|
||||
sample_guidance_scale = sample_guidance_scale,
|
||||
train_batch_size = train_batch_size,
|
||||
train_use_8bit_adam = train_use_8bit_adam,
|
||||
train_learning_rate = train_learning_rate,
|
||||
train_adam_beta1 = train_adam_beta1,
|
||||
train_adam_beta2 = train_adam_beta2,
|
||||
train_adam_weight_decay = train_adam_weight_decay,
|
||||
train_adam_epsilon = train_adam_epsilon,
|
||||
train_gradient_accumulation_steps = train_gradient_accumulation_steps,
|
||||
train_max_grad_norm = train_max_grad_norm,
|
||||
negative_prompts = negative_prompts,
|
||||
truncated_backprop_rand = truncated_backprop_rand,
|
||||
truncated_backprop_timestep = truncated_backprop_timestep,
|
||||
push_to_hub = push_to_hub,**kwargs)
|
||||
self.vllm_sampling_params = vllm_sampling_params
|
||||
self.unsloth_num_chunks = unsloth_num_chunks
|
||||
pass
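# Illustrative only: the config is a drop-in replacement for `AlignPropConfig` and can be built by
# overriding any of the defaults above. The specific values below are placeholder assumptions, not
# recommendations.
#
#     config = UnslothAlignPropConfig(
#         num_epochs = 10,
#         train_batch_size = 2,
#         train_gradient_accumulation_steps = 4,
#         truncated_backprop_rand = True,
#     )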
|
||||
|
||||
class _UnslothAlignPropTrainer(PyTorchModelHubMixin):
|
||||
""""""
|
||||
|
||||
_tag_names = ["trl", "alignprop"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: AlignPropConfig,
|
||||
reward_function: Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor],
|
||||
prompt_function: Callable[[], tuple[str, Any]],
|
||||
sd_pipeline: DDPOStableDiffusionPipeline,
|
||||
image_samples_hook: Optional[Callable[[Any, Any, Any], Any]] = None,
|
||||
):
|
||||
if image_samples_hook is None:
|
||||
warn("No image_samples_hook provided; no images will be logged")
|
||||
|
||||
self.prompt_fn = prompt_function
|
||||
self.reward_fn = reward_function
|
||||
self.config = config
|
||||
self.image_samples_callback = image_samples_hook
|
||||
|
||||
accelerator_project_config = ProjectConfiguration(**self.config.project_kwargs)
|
||||
|
||||
if self.config.resume_from:
|
||||
self.config.resume_from = os.path.normpath(os.path.expanduser(self.config.resume_from))
|
||||
if "checkpoint_" not in os.path.basename(self.config.resume_from):
|
||||
# get the most recent checkpoint in this directory
|
||||
checkpoints = list(
|
||||
filter(
|
||||
lambda x: "checkpoint_" in x,
|
||||
os.listdir(self.config.resume_from),
|
||||
)
|
||||
)
|
||||
if len(checkpoints) == 0:
|
||||
raise ValueError(f"No checkpoints found in {self.config.resume_from}")
|
||||
checkpoint_numbers = sorted([int(x.split("_")[-1]) for x in checkpoints])
|
||||
self.config.resume_from = os.path.join(
|
||||
self.config.resume_from,
|
||||
f"checkpoint_{checkpoint_numbers[-1]}",
|
||||
)
|
||||
|
||||
accelerator_project_config.iteration = checkpoint_numbers[-1] + 1
|
||||
|
||||
self.accelerator = Accelerator(
|
||||
log_with=self.config.log_with,
|
||||
mixed_precision=self.config.mixed_precision,
|
||||
project_config=accelerator_project_config,
|
||||
# Note: unlike DDPO, AlignProp accumulates gradients directly over samples, so
# config.train_gradient_accumulation_steps is passed through unchanged rather than being
# multiplied by the number of training timesteps.
gradient_accumulation_steps=self.config.train_gradient_accumulation_steps,
|
||||
**self.config.accelerator_kwargs,
|
||||
)
|
||||
|
||||
is_using_tensorboard = config.log_with is not None and config.log_with == "tensorboard"
|
||||
|
||||
if self.accelerator.is_main_process:
|
||||
self.accelerator.init_trackers(
|
||||
self.config.tracker_project_name,
|
||||
config=dict(alignprop_trainer_config=config.to_dict())
|
||||
if not is_using_tensorboard
|
||||
else config.to_dict(),
|
||||
init_kwargs=self.config.tracker_kwargs,
|
||||
)
|
||||
|
||||
logger.info(f"\n{config}")
|
||||
|
||||
set_seed(self.config.seed, device_specific=True)
|
||||
|
||||
self.sd_pipeline = sd_pipeline
|
||||
|
||||
self.sd_pipeline.set_progress_bar_config(
|
||||
position=1,
|
||||
disable=not self.accelerator.is_local_main_process,
|
||||
leave=False,
|
||||
desc="Timestep",
|
||||
dynamic_ncols=True,
|
||||
)
|
||||
|
||||
# For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
|
||||
# as these weights are only used for inference, keeping weights in full precision is not required.
|
||||
if self.accelerator.mixed_precision == "fp16":
|
||||
inference_dtype = torch.float16
|
||||
elif self.accelerator.mixed_precision == "bf16":
|
||||
inference_dtype = torch.bfloat16
|
||||
else:
|
||||
inference_dtype = torch.float32
|
||||
|
||||
self.sd_pipeline.vae.to(self.accelerator.device, dtype=inference_dtype)
|
||||
self.sd_pipeline.text_encoder.to(self.accelerator.device, dtype=inference_dtype)
|
||||
self.sd_pipeline.unet.to(self.accelerator.device, dtype=inference_dtype)
|
||||
|
||||
trainable_layers = self.sd_pipeline.get_trainable_layers()
|
||||
|
||||
self.accelerator.register_save_state_pre_hook(self._save_model_hook)
|
||||
self.accelerator.register_load_state_pre_hook(self._load_model_hook)
|
||||
|
||||
# Enable TF32 for faster training on Ampere GPUs,
|
||||
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
|
||||
if self.config.allow_tf32:
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
||||
self.optimizer = self._setup_optimizer(
|
||||
trainable_layers.parameters() if not isinstance(trainable_layers, list) else trainable_layers
|
||||
)
|
||||
|
||||
self.neg_prompt_embed = self.sd_pipeline.text_encoder(
|
||||
self.sd_pipeline.tokenizer(
|
||||
[""] if self.config.negative_prompts is None else self.config.negative_prompts,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=self.sd_pipeline.tokenizer.model_max_length,
|
||||
).input_ids.to(self.accelerator.device)
|
||||
)[0]
|
||||
|
||||
# NOTE: for some reason, autocast is necessary for non-lora training but for lora training it isn't necessary and it uses
|
||||
# more memory
|
||||
self.autocast = self.sd_pipeline.autocast or self.accelerator.autocast
|
||||
|
||||
if hasattr(self.sd_pipeline, "use_lora") and self.sd_pipeline.use_lora:
|
||||
unet, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)
|
||||
self.trainable_layers = list(filter(lambda p: p.requires_grad, unet.parameters()))
|
||||
else:
|
||||
self.trainable_layers, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)
|
||||
|
||||
if config.resume_from:
|
||||
logger.info(f"Resuming from {config.resume_from}")
|
||||
self.accelerator.load_state(config.resume_from)
|
||||
self.first_epoch = int(config.resume_from.split("_")[-1]) + 1
|
||||
else:
|
||||
self.first_epoch = 0
|
||||
|
||||
def compute_rewards(self, prompt_image_pairs):
|
||||
reward, reward_metadata = self.reward_fn(
|
||||
prompt_image_pairs["images"], prompt_image_pairs["prompts"], prompt_image_pairs["prompt_metadata"]
|
||||
)
|
||||
return reward
|
||||
|
||||
def step(self, epoch: int, global_step: int):
|
||||
"""
|
||||
Perform a single step of training.
|
||||
|
||||
Args:
|
||||
epoch (int): The current epoch.
|
||||
global_step (int): The current global step.
|
||||
|
||||
Side Effects:
|
||||
- Model weights are updated
|
||||
- Logs the statistics to the accelerator trackers.
|
||||
- If `self.image_samples_callback` is not None, it will be called with the prompt_image_pairs, global_step, and the accelerator tracker.
|
||||
|
||||
Returns:
|
||||
global_step (int): The updated global step.
|
||||
"""
|
||||
info = defaultdict(list)
|
||||
|
||||
self.sd_pipeline.unet.train()
|
||||
|
||||
for _ in range(self.config.train_gradient_accumulation_steps):
|
||||
with self.accelerator.accumulate(self.sd_pipeline.unet), self.autocast(), torch.enable_grad():
|
||||
prompt_image_pairs = self._generate_samples(
|
||||
batch_size=self.config.train_batch_size,
|
||||
)
|
||||
|
||||
rewards = self.compute_rewards(prompt_image_pairs)
|
||||
|
||||
prompt_image_pairs["rewards"] = rewards
|
||||
|
||||
rewards_vis = self.accelerator.gather(rewards).detach().cpu().numpy()
|
||||
|
||||
loss = self.calculate_loss(rewards)
|
||||
|
||||
self.accelerator.backward(loss)
|
||||
|
||||
if self.accelerator.sync_gradients:
|
||||
self.accelerator.clip_grad_norm_(
|
||||
self.trainable_layers.parameters()
|
||||
if not isinstance(self.trainable_layers, list)
|
||||
else self.trainable_layers,
|
||||
self.config.train_max_grad_norm,
|
||||
)
|
||||
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
info["reward_mean"].append(rewards_vis.mean())
|
||||
info["reward_std"].append(rewards_vis.std())
|
||||
info["loss"].append(loss.item())
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if self.accelerator.sync_gradients:
|
||||
# log training-related stuff
|
||||
info = {k: torch.mean(torch.tensor(v)) for k, v in info.items()}
|
||||
info = self.accelerator.reduce(info, reduction="mean")
|
||||
info.update({"epoch": epoch})
|
||||
self.accelerator.log(info, step=global_step)
|
||||
global_step += 1
|
||||
info = defaultdict(list)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Optimization step should have been performed by this point. Please check calculated gradient accumulation settings."
|
||||
)
|
||||
# Logs generated images
|
||||
if self.image_samples_callback is not None and global_step % self.config.log_image_freq == 0:
|
||||
self.image_samples_callback(prompt_image_pairs, global_step, self.accelerator.trackers[0])
|
||||
|
||||
if epoch != 0 and epoch % self.config.save_freq == 0 and self.accelerator.is_main_process:
|
||||
self.accelerator.save_state()
|
||||
|
||||
return global_step
|
||||
|
||||
    def calculate_loss(self, rewards):
        """
        Calculate the loss for a batch of unpacked samples.

        Args:
            rewards (torch.Tensor):
                Differentiable reward scalars for each generated image, shape: [batch_size]

        Returns:
            loss (torch.Tensor): scalar loss.
        """
        # Loss is specific to the Aesthetic Reward function used in AlignProp (https://huggingface.co/papers/2310.03739)
        loss = 10.0 - (rewards).mean()
        return loss
|
||||
|
||||
def loss(
|
||||
self,
|
||||
advantages: torch.Tensor,
|
||||
clip_range: float,
|
||||
ratio: torch.Tensor,
|
||||
):
|
||||
unclipped_loss = -advantages * ratio
|
||||
clipped_loss = -advantages * torch.clamp(
|
||||
ratio,
|
||||
1.0 - clip_range,
|
||||
1.0 + clip_range,
|
||||
)
|
||||
return torch.mean(torch.maximum(unclipped_loss, clipped_loss))
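# Illustrative only: the method above is the PPO-style clipped surrogate objective. With a single
# sample, advantage = 1.0, clip_range = 0.2 and ratio = 1.5, the ratio is clamped to 1.2, so the
# loss is max(-1.0 * 1.5, -1.0 * 1.2) = -1.2; updates cannot profit from ratios beyond 1 + clip_range.
#
#     adv, ratio = torch.tensor([1.0]), torch.tensor([1.5])
#     clipped    = torch.clamp(ratio, 1.0 - 0.2, 1.0 + 0.2)                 # tensor([1.2])
#     loss       = torch.mean(torch.maximum(-adv * ratio, -adv * clipped))  # tensor(-1.2000)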
|
||||
|
||||
def _setup_optimizer(self, trainable_layers_parameters):
|
||||
if self.config.train_use_8bit_adam:
|
||||
import bitsandbytes
|
||||
|
||||
optimizer_cls = bitsandbytes.optim.AdamW8bit
|
||||
else:
|
||||
optimizer_cls = torch.optim.AdamW
|
||||
|
||||
return optimizer_cls(
|
||||
trainable_layers_parameters,
|
||||
lr=self.config.train_learning_rate,
|
||||
betas=(self.config.train_adam_beta1, self.config.train_adam_beta2),
|
||||
weight_decay=self.config.train_adam_weight_decay,
|
||||
eps=self.config.train_adam_epsilon,
|
||||
)
|
||||
|
||||
def _save_model_hook(self, models, weights, output_dir):
|
||||
self.sd_pipeline.save_checkpoint(models, weights, output_dir)
|
||||
weights.pop() # ensures that accelerate doesn't try to handle saving of the model
|
||||
|
||||
def _load_model_hook(self, models, input_dir):
|
||||
self.sd_pipeline.load_checkpoint(models, input_dir)
|
||||
models.pop() # ensures that accelerate doesn't try to handle loading of the model
|
||||
|
||||
def _generate_samples(self, batch_size, with_grad=True, prompts=None):
|
||||
"""
|
||||
Generate samples from the model
|
||||
|
||||
Args:
|
||||
batch_size (int): Batch size to use for sampling
|
||||
with_grad (bool): Whether the generated RGBs should have gradients attached to it.
|
||||
|
||||
Returns:
|
||||
prompt_image_pairs (dict[Any])
|
||||
"""
|
||||
prompt_image_pairs = {}
|
||||
|
||||
sample_neg_prompt_embeds = self.neg_prompt_embed.repeat(batch_size, 1, 1)
|
||||
|
||||
if prompts is None:
|
||||
prompts, prompt_metadata = zip(*[self.prompt_fn() for _ in range(batch_size)])
|
||||
else:
|
||||
prompt_metadata = [{} for _ in range(batch_size)]
|
||||
|
||||
prompt_ids = self.sd_pipeline.tokenizer(
|
||||
prompts,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=self.sd_pipeline.tokenizer.model_max_length,
|
||||
).input_ids.to(self.accelerator.device)
|
||||
|
||||
prompt_embeds = self.sd_pipeline.text_encoder(prompt_ids)[0]
|
||||
|
||||
if with_grad:
|
||||
sd_output = self.sd_pipeline.rgb_with_grad(
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=sample_neg_prompt_embeds,
|
||||
num_inference_steps=self.config.sample_num_steps,
|
||||
guidance_scale=self.config.sample_guidance_scale,
|
||||
eta=self.config.sample_eta,
|
||||
truncated_backprop_rand=self.config.truncated_backprop_rand,
|
||||
truncated_backprop_timestep=self.config.truncated_backprop_timestep,
|
||||
truncated_rand_backprop_minmax=self.config.truncated_rand_backprop_minmax,
|
||||
output_type="pt",
|
||||
)
|
||||
else:
|
||||
sd_output = self.sd_pipeline(
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=sample_neg_prompt_embeds,
|
||||
num_inference_steps=self.config.sample_num_steps,
|
||||
guidance_scale=self.config.sample_guidance_scale,
|
||||
eta=self.config.sample_eta,
|
||||
output_type="pt",
|
||||
)
|
||||
|
||||
images = sd_output.images
|
||||
|
||||
prompt_image_pairs["images"] = images
|
||||
prompt_image_pairs["prompts"] = prompts
|
||||
prompt_image_pairs["prompt_metadata"] = prompt_metadata
|
||||
|
||||
return prompt_image_pairs
|
||||
|
||||
def train(self, epochs: Optional[int] = None):
|
||||
"""
|
||||
Train the model for a given number of epochs
|
||||
"""
|
||||
global_step = 0
|
||||
if epochs is None:
|
||||
epochs = self.config.num_epochs
|
||||
for epoch in range(self.first_epoch, epochs):
|
||||
global_step = self.step(epoch, global_step)
|
||||
|
||||
def _save_pretrained(self, save_directory):
|
||||
self.sd_pipeline.save_pretrained(save_directory)
|
||||
self.create_model_card()
|
||||
|
||||
def create_model_card(
|
||||
self,
|
||||
model_name: Optional[str] = None,
|
||||
dataset_name: Optional[str] = None,
|
||||
tags: Union[str, list[str], None] = None,
|
||||
):
|
||||
"""
|
||||
Creates a draft of a model card using the information available to the `Trainer`.
|
||||
|
||||
Args:
|
||||
model_name (`str` or `None`, *optional*, defaults to `None`):
|
||||
Name of the model.
|
||||
dataset_name (`str` or `None`, *optional*, defaults to `None`):
|
||||
Name of the dataset used for training.
|
||||
tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
|
||||
Tags to be associated with the model card.
|
||||
"""
|
||||
if not self.is_world_process_zero():
|
||||
return
|
||||
|
||||
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
|
||||
base_model = self.model.config._name_or_path
|
||||
else:
|
||||
base_model = None
|
||||
|
||||
tags = tags or []
|
||||
if isinstance(tags, str):
|
||||
tags = [tags]
|
||||
|
||||
if hasattr(self.model.config, "unsloth_version"):
|
||||
tags.append("unsloth")
|
||||
|
||||
citation = textwrap.dedent("""\
|
||||
@article{prabhudesai2024aligning,
|
||||
title = {{Aligning Text-to-Image Diffusion Models with Reward Backpropagation}},
|
||||
author = {Mihir Prabhudesai and Anirudh Goyal and Deepak Pathak and Katerina Fragkiadaki},
|
||||
year = 2024,
|
||||
eprint = {arXiv:2310.03739}
|
||||
}""")
|
||||
|
||||
model_card = generate_model_card(
|
||||
base_model=base_model,
|
||||
model_name=model_name,
|
||||
hub_model_id=self.hub_model_id,
|
||||
dataset_name=dataset_name,
|
||||
tags=tags,
|
||||
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
|
||||
comet_url=get_comet_experiment_url(),
|
||||
trainer_name="AlignProp",
|
||||
trainer_citation=citation,
|
||||
paper_title="Aligning Text-to-Image Diffusion Models with Reward Backpropagation",
|
||||
paper_id="2310.03739",
|
||||
)
|
||||
|
||||
model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
||||
class UnslothAlignPropTrainer(_UnslothAlignPropTrainer):
    """
    The AlignPropTrainer uses reward backpropagation (AlignProp) to optimise diffusion models.
    Note that this trainer is heavily inspired by the work here: https://github.com/mihirp1998/AlignProp/
    As of now, only Stable Diffusion based pipelines are supported.

    Attributes:
        config (`AlignPropConfig`):
            Configuration object for AlignPropTrainer. Check the documentation of `PPOConfig` for more details.
        reward_function (`Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor]`):
            Reward function to be used.
        prompt_function (`Callable[[], tuple[str, Any]]`):
            Function to generate prompts to guide the model.
        sd_pipeline (`DDPOStableDiffusionPipeline`):
            Stable Diffusion pipeline to be used for training.
        image_samples_hook (`Optional[Callable[[Any, Any, Any], Any]]`):
            Hook to be called to log images.
    """
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
reward_function,
|
||||
prompt_function,
|
||||
sd_pipeline,
|
||||
image_samples_hook = None,
|
||||
**kwargs
|
||||
):
|
||||
if config is None: config = UnslothAlignPropConfig()
|
||||
other_metrics = []
|
||||
|
||||
from unsloth_zoo.logging_utils import PatchRLStatistics
|
||||
PatchRLStatistics('alignprop_trainer', other_metrics)
|
||||
|
||||
super().__init__(
|
||||
config = config,
|
||||
reward_function = reward_function,
|
||||
prompt_function = prompt_function,
|
||||
sd_pipeline = sd_pipeline,
|
||||
image_samples_hook = image_samples_hook,**kwargs)
|
||||
|
||||
pass
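# Usage sketch (assumptions, not part of the generated cache): `my_prompt_fn`, `my_reward_fn` and
# the pipeline/model id below are placeholders; any `DDPOStableDiffusionPipeline` implementation
# should work, e.g. `DefaultDDPOStableDiffusionPipeline` from TRL.
#
#     from trl import DefaultDDPOStableDiffusionPipeline
#
#     def my_prompt_fn():
#         return "a watercolor painting of a fox", {}
#
#     def my_reward_fn(images, prompts, prompt_metadata):
#         return images.mean(dim = (1, 2, 3)), {}  # dummy differentiable "reward", one scalar per image
#
#     pipeline = DefaultDDPOStableDiffusionPipeline("runwayml/stable-diffusion-v1-5")
#     trainer = UnslothAlignPropTrainer(
#         config = UnslothAlignPropConfig(num_epochs = 1),
#         reward_function = my_reward_fn,
#         prompt_function = my_prompt_fn,
#         sd_pipeline = pipeline,
#     )
#     trainer.train()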
|
||||
unsloth_compiled_cache/UnslothBCOTrainer.py (1781 lines, new file): File diff suppressed because it is too large
unsloth_compiled_cache/UnslothCPOTrainer.py (1554 lines, new file): File diff suppressed because it is too large
unsloth_compiled_cache/UnslothDDPOTrainer.py (872 lines, new file)
@@ -0,0 +1,872 @@
"""
2025.6.1
2025.6.2
4.52.4
0.18.2
__UNSLOTH_VERSIONING__
"""
from torch import Tensor
import torch
import torch.nn as nn
from torch.nn import functional as F
from trl.trainer.ddpo_trainer import (Accelerator, Any, Callable, DDPOConfig, DDPOStableDiffusionPipeline, DDPOTrainer, Optional, PerPromptStatTracker, ProjectConfiguration, PyTorchModelHubMixin, Union, defaultdict, futures, generate_model_card, get_comet_experiment_url, is_wandb_available, logger, os, set_seed, textwrap, torch, warn)


import os
from typing import *
from dataclasses import dataclass, field
from packaging.version import Version
import torch
import numpy as np
from contextlib import nullcontext
from torch.nn import functional as F
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling

torch_compile_options = {
    "epilogue_fusion"   : True,
    "max_autotune"      : False,
    "shape_padding"     : True,
    "trace.enabled"     : False,
    "triton.cudagraphs" : False,
}

@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
def selective_log_softmax(logits, index):
    logits = logits.to(torch.float32)
    selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
    # loop to reduce peak mem consumption
    # logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
    logsumexp_values = torch.logsumexp(logits, dim = -1)
    per_token_logps = selected_logits - logsumexp_values  # log_softmax(x_i) = x_i - logsumexp(x)
    return per_token_logps
@dataclass
|
||||
class UnslothDDPOConfig(DDPOConfig):
|
||||
"""
|
||||
|
||||
Configuration class for the [`DDPOTrainer`].
|
||||
|
||||
Using [`~transformers.HfArgumentParser`] we can turn this class into
|
||||
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
|
||||
command line.
|
||||
|
||||
Parameters:
|
||||
exp_name (`str`, *optional*, defaults to `os.path.basename(sys.argv[0])[: -len(".py")]`):
|
||||
Name of this experiment (defaults to the file name without the extension).
|
||||
run_name (`str`, *optional*, defaults to `""`):
|
||||
Name of this run.
|
||||
seed (`int`, *optional*, defaults to `0`):
|
||||
Random seed.
|
||||
log_with (`Literal["wandb", "tensorboard"]` or `None`, *optional*, defaults to `None`):
|
||||
Log with either 'wandb' or 'tensorboard', check
|
||||
https://huggingface.co/docs/accelerate/usage_guides/tracking for more details.
|
||||
tracker_kwargs (`Dict`, *optional*, defaults to `{}`):
|
||||
Keyword arguments for the tracker (e.g. wandb_project).
|
||||
accelerator_kwargs (`Dict`, *optional*, defaults to `{}`):
|
||||
Keyword arguments for the accelerator.
|
||||
project_kwargs (`Dict`, *optional*, defaults to `{}`):
|
||||
Keyword arguments for the accelerator project config (e.g. `logging_dir`).
|
||||
tracker_project_name (`str`, *optional*, defaults to `"trl"`):
|
||||
Name of project to use for tracking.
|
||||
logdir (`str`, *optional*, defaults to `"logs"`):
|
||||
Top-level logging directory for checkpoint saving.
|
||||
num_epochs (`int`, *optional*, defaults to `100`):
|
||||
Number of epochs to train.
|
||||
save_freq (`int`, *optional*, defaults to `1`):
|
||||
Number of epochs between saving model checkpoints.
|
||||
num_checkpoint_limit (`int`, *optional*, defaults to `5`):
|
||||
Number of checkpoints to keep before overwriting old ones.
|
||||
mixed_precision (`str`, *optional*, defaults to `"fp16"`):
|
||||
Mixed precision training.
|
||||
allow_tf32 (`bool`, *optional*, defaults to `True`):
|
||||
Allow `tf32` on Ampere GPUs.
|
||||
resume_from (`str`, *optional*, defaults to `""`):
|
||||
Resume training from a checkpoint.
|
||||
sample_num_steps (`int`, *optional*, defaults to `50`):
|
||||
Number of sampler inference steps.
|
||||
sample_eta (`float`, *optional*, defaults to `1.0`):
|
||||
Eta parameter for the DDIM sampler.
|
||||
sample_guidance_scale (`float`, *optional*, defaults to `5.0`):
|
||||
Classifier-free guidance weight.
|
||||
sample_batch_size (`int`, *optional*, defaults to `1`):
|
||||
Batch size (per GPU) to use for sampling.
|
||||
sample_num_batches_per_epoch (`int`, *optional*, defaults to `2`):
|
||||
Number of batches to sample per epoch.
|
||||
train_batch_size (`int`, *optional*, defaults to `1`):
|
||||
Batch size (per GPU) to use for training.
|
||||
train_use_8bit_adam (`bool`, *optional*, defaults to `False`):
|
||||
Use 8bit Adam optimizer from bitsandbytes.
|
||||
train_learning_rate (`float`, *optional*, defaults to `3e-4`):
|
||||
Learning rate.
|
||||
train_adam_beta1 (`float`, *optional*, defaults to `0.9`):
|
||||
Adam beta1.
|
||||
train_adam_beta2 (`float`, *optional*, defaults to `0.999`):
|
||||
Adam beta2.
|
||||
train_adam_weight_decay (`float`, *optional*, defaults to `1e-4`):
|
||||
Adam weight decay.
|
||||
train_adam_epsilon (`float`, *optional*, defaults to `1e-8`):
|
||||
Adam epsilon.
|
||||
train_gradient_accumulation_steps (`int`, *optional*, defaults to `1`):
|
||||
Number of gradient accumulation steps.
|
||||
train_max_grad_norm (`float`, *optional*, defaults to `1.0`):
|
||||
Maximum gradient norm for gradient clipping.
|
||||
train_num_inner_epochs (`int`, *optional*, defaults to `1`):
|
||||
Number of inner epochs per outer epoch.
|
||||
train_cfg (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use classifier-free guidance during training.
|
||||
train_adv_clip_max (`float`, *optional*, defaults to `5.0`):
|
||||
Clip advantages to the range `[-train_adv_clip_max, train_adv_clip_max]`.
|
||||
train_clip_range (`float`, *optional*, defaults to `1e-4`):
|
||||
PPO clip range.
|
||||
train_timestep_fraction (`float`, *optional*, defaults to `1.0`):
|
||||
Fraction of timesteps to train on.
|
||||
per_prompt_stat_tracking (`bool`, *optional*, defaults to `False`):
|
||||
Whether to track statistics for each prompt separately.
|
||||
per_prompt_stat_tracking_buffer_size (`int`, *optional*, defaults to `16`):
|
||||
Number of reward values to store in the buffer for each prompt.
|
||||
per_prompt_stat_tracking_min_count (`int`, *optional*, defaults to `16`):
|
||||
Minimum number of reward values to store in the buffer.
|
||||
async_reward_computation (`bool`, *optional*, defaults to `False`):
|
||||
Whether to compute rewards asynchronously.
|
||||
max_workers (`int`, *optional*, defaults to `2`):
|
||||
Maximum number of workers to use for async reward computation.
|
||||
negative_prompts (`str`, *optional*, defaults to `""`):
|
||||
Comma-separated list of prompts to use as negative examples.
|
||||
push_to_hub (`bool`, *optional*, defaults to `False`):
|
||||
Whether to push the final model checkpoint to the Hub.
|
||||
|
||||
"""
|
||||
vllm_sampling_params: Optional[Any] = field(
|
||||
default = None,
|
||||
metadata = {'help': 'vLLM SamplingParams'},
|
||||
)
|
||||
unsloth_num_chunks : Optional[int] = field(
|
||||
default = -1,
|
||||
metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
|
||||
)
|
||||
def __init__(
|
||||
self,
|
||||
exp_name = 'test',
|
||||
run_name = '',
|
||||
seed = 3407,
|
||||
log_with = None,
|
||||
tracker_project_name = 'trl',
|
||||
logdir = 'logs',
|
||||
num_epochs = 100,
|
||||
save_freq = 1,
|
||||
num_checkpoint_limit = 5,
|
||||
mixed_precision = 'fp16',
|
||||
allow_tf32 = True,
|
||||
resume_from = '',
|
||||
sample_num_steps = 50,
|
||||
sample_eta = 1.0,
|
||||
sample_guidance_scale = 5.0,
|
||||
sample_batch_size = 1,
|
||||
sample_num_batches_per_epoch = 2,
|
||||
train_batch_size = 1,
|
||||
train_use_8bit_adam = False,
|
||||
train_learning_rate = 5e-05,
|
||||
train_adam_beta1 = 0.9,
|
||||
train_adam_beta2 = 0.999,
|
||||
train_adam_weight_decay = 0.01,
|
||||
train_adam_epsilon = 1e-08,
|
||||
train_gradient_accumulation_steps = 2,
|
||||
train_max_grad_norm = 1.0,
|
||||
train_num_inner_epochs = 1,
|
||||
train_cfg = True,
|
||||
train_adv_clip_max = 5.0,
|
||||
train_clip_range = 0.0001,
|
||||
train_timestep_fraction = 1.0,
|
||||
per_prompt_stat_tracking = False,
|
||||
per_prompt_stat_tracking_buffer_size = 16,
|
||||
per_prompt_stat_tracking_min_count = 16,
|
||||
async_reward_computation = False,
|
||||
max_workers = 2,
|
||||
negative_prompts = '',
|
||||
push_to_hub = False,
|
||||
vllm_sampling_params = None,
|
||||
unsloth_num_chunks = -1,
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
super().__init__(
|
||||
exp_name = exp_name,
|
||||
run_name = run_name,
|
||||
seed = seed,
|
||||
log_with = log_with,
|
||||
tracker_project_name = tracker_project_name,
|
||||
logdir = logdir,
|
||||
num_epochs = num_epochs,
|
||||
save_freq = save_freq,
|
||||
num_checkpoint_limit = num_checkpoint_limit,
|
||||
mixed_precision = mixed_precision,
|
||||
allow_tf32 = allow_tf32,
|
||||
resume_from = resume_from,
|
||||
sample_num_steps = sample_num_steps,
|
||||
sample_eta = sample_eta,
|
||||
sample_guidance_scale = sample_guidance_scale,
|
||||
sample_batch_size = sample_batch_size,
|
||||
sample_num_batches_per_epoch = sample_num_batches_per_epoch,
|
||||
train_batch_size = train_batch_size,
|
||||
train_use_8bit_adam = train_use_8bit_adam,
|
||||
train_learning_rate = train_learning_rate,
|
||||
train_adam_beta1 = train_adam_beta1,
|
||||
train_adam_beta2 = train_adam_beta2,
|
||||
train_adam_weight_decay = train_adam_weight_decay,
|
||||
train_adam_epsilon = train_adam_epsilon,
|
||||
train_gradient_accumulation_steps = train_gradient_accumulation_steps,
|
||||
train_max_grad_norm = train_max_grad_norm,
|
||||
train_num_inner_epochs = train_num_inner_epochs,
|
||||
train_cfg = train_cfg,
|
||||
train_adv_clip_max = train_adv_clip_max,
|
||||
train_clip_range = train_clip_range,
|
||||
train_timestep_fraction = train_timestep_fraction,
|
||||
per_prompt_stat_tracking = per_prompt_stat_tracking,
|
||||
per_prompt_stat_tracking_buffer_size = per_prompt_stat_tracking_buffer_size,
|
||||
per_prompt_stat_tracking_min_count = per_prompt_stat_tracking_min_count,
|
||||
async_reward_computation = async_reward_computation,
|
||||
max_workers = max_workers,
|
||||
negative_prompts = negative_prompts,
|
||||
push_to_hub = push_to_hub,**kwargs)
|
||||
self.vllm_sampling_params = vllm_sampling_params
|
||||
self.unsloth_num_chunks = unsloth_num_chunks
|
||||
pass
|
||||
|
||||
class _UnslothDDPOTrainer(PyTorchModelHubMixin):
|
||||
""""""
|
||||
|
||||
_tag_names = ["trl", "ddpo"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: DDPOConfig,
|
||||
reward_function: Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor],
|
||||
prompt_function: Callable[[], tuple[str, Any]],
|
||||
sd_pipeline: DDPOStableDiffusionPipeline,
|
||||
image_samples_hook: Optional[Callable[[Any, Any, Any], Any]] = None,
|
||||
):
|
||||
if image_samples_hook is None:
|
||||
warn("No image_samples_hook provided; no images will be logged")
|
||||
|
||||
self.prompt_fn = prompt_function
|
||||
self.reward_fn = reward_function
|
||||
self.config = config
|
||||
self.image_samples_callback = image_samples_hook
|
||||
|
||||
accelerator_project_config = ProjectConfiguration(**self.config.project_kwargs)
|
||||
|
||||
if self.config.resume_from:
|
||||
self.config.resume_from = os.path.normpath(os.path.expanduser(self.config.resume_from))
|
||||
if "checkpoint_" not in os.path.basename(self.config.resume_from):
|
||||
# get the most recent checkpoint in this directory
|
||||
checkpoints = list(
|
||||
filter(
|
||||
lambda x: "checkpoint_" in x,
|
||||
os.listdir(self.config.resume_from),
|
||||
)
|
||||
)
|
||||
if len(checkpoints) == 0:
|
||||
raise ValueError(f"No checkpoints found in {self.config.resume_from}")
|
||||
checkpoint_numbers = sorted([int(x.split("_")[-1]) for x in checkpoints])
|
||||
self.config.resume_from = os.path.join(
|
||||
self.config.resume_from,
|
||||
f"checkpoint_{checkpoint_numbers[-1]}",
|
||||
)
|
||||
|
||||
accelerator_project_config.iteration = checkpoint_numbers[-1] + 1
|
||||
|
||||
# number of timesteps within each trajectory to train on
|
||||
self.num_train_timesteps = int(self.config.sample_num_steps * self.config.train_timestep_fraction)
|
||||
|
||||
self.accelerator = Accelerator(
|
||||
log_with=self.config.log_with,
|
||||
mixed_precision=self.config.mixed_precision,
|
||||
project_config=accelerator_project_config,
|
||||
# we always accumulate gradients across timesteps; we want config.train.gradient_accumulation_steps to be the
|
||||
# number of *samples* we accumulate across, so we need to multiply by the number of training timesteps to get
|
||||
# the total number of optimizer steps to accumulate across.
|
||||
gradient_accumulation_steps=self.config.train_gradient_accumulation_steps * self.num_train_timesteps,
|
||||
**self.config.accelerator_kwargs,
|
||||
)
|
||||
|
||||
is_okay, message = self._config_check()
|
||||
if not is_okay:
|
||||
raise ValueError(message)
|
||||
|
||||
is_using_tensorboard = config.log_with is not None and config.log_with == "tensorboard"
|
||||
|
||||
if self.accelerator.is_main_process:
|
||||
self.accelerator.init_trackers(
|
||||
self.config.tracker_project_name,
|
||||
config=dict(ddpo_trainer_config=config.to_dict()) if not is_using_tensorboard else config.to_dict(),
|
||||
init_kwargs=self.config.tracker_kwargs,
|
||||
)
|
||||
|
||||
logger.info(f"\n{config}")
|
||||
|
||||
set_seed(self.config.seed, device_specific=True)
|
||||
|
||||
self.sd_pipeline = sd_pipeline
|
||||
|
||||
self.sd_pipeline.set_progress_bar_config(
|
||||
position=1,
|
||||
disable=not self.accelerator.is_local_main_process,
|
||||
leave=False,
|
||||
desc="Timestep",
|
||||
dynamic_ncols=True,
|
||||
)
|
||||
|
||||
# For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
|
||||
# as these weights are only used for inference, keeping weights in full precision is not required.
|
||||
if self.accelerator.mixed_precision == "fp16":
|
||||
inference_dtype = torch.float16
|
||||
elif self.accelerator.mixed_precision == "bf16":
|
||||
inference_dtype = torch.bfloat16
|
||||
else:
|
||||
inference_dtype = torch.float32
|
||||
|
||||
self.sd_pipeline.vae.to(self.accelerator.device, dtype=inference_dtype)
|
||||
self.sd_pipeline.text_encoder.to(self.accelerator.device, dtype=inference_dtype)
|
||||
self.sd_pipeline.unet.to(self.accelerator.device, dtype=inference_dtype)
|
||||
|
||||
trainable_layers = self.sd_pipeline.get_trainable_layers()
|
||||
|
||||
self.accelerator.register_save_state_pre_hook(self._save_model_hook)
|
||||
self.accelerator.register_load_state_pre_hook(self._load_model_hook)
|
||||
|
||||
# Enable TF32 for faster training on Ampere GPUs,
|
||||
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
|
||||
if self.config.allow_tf32:
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
||||
self.optimizer = self._setup_optimizer(
|
||||
trainable_layers.parameters() if not isinstance(trainable_layers, list) else trainable_layers
|
||||
)
|
||||
|
||||
self.neg_prompt_embed = self.sd_pipeline.text_encoder(
|
||||
self.sd_pipeline.tokenizer(
|
||||
[""] if self.config.negative_prompts is None else self.config.negative_prompts,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=self.sd_pipeline.tokenizer.model_max_length,
|
||||
).input_ids.to(self.accelerator.device)
|
||||
)[0]
|
||||
|
||||
if config.per_prompt_stat_tracking:
|
||||
self.stat_tracker = PerPromptStatTracker(
|
||||
config.per_prompt_stat_tracking_buffer_size,
|
||||
config.per_prompt_stat_tracking_min_count,
|
||||
)
|
||||
|
||||
# NOTE: for some reason, autocast is necessary for non-lora training but for lora training it isn't necessary and it uses
|
||||
# more memory
|
||||
self.autocast = self.sd_pipeline.autocast or self.accelerator.autocast
|
||||
|
||||
if hasattr(self.sd_pipeline, "use_lora") and self.sd_pipeline.use_lora:
|
||||
unet, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)
|
||||
self.trainable_layers = list(filter(lambda p: p.requires_grad, unet.parameters()))
|
||||
else:
|
||||
self.trainable_layers, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)
|
||||
|
||||
if self.config.async_reward_computation:
|
||||
self.executor = futures.ThreadPoolExecutor(max_workers=config.max_workers)
|
||||
|
||||
if config.resume_from:
|
||||
logger.info(f"Resuming from {config.resume_from}")
|
||||
self.accelerator.load_state(config.resume_from)
|
||||
self.first_epoch = int(config.resume_from.split("_")[-1]) + 1
|
||||
else:
|
||||
self.first_epoch = 0
|
||||
|
||||
def compute_rewards(self, prompt_image_pairs, is_async=False):
|
||||
if not is_async:
|
||||
rewards = []
|
||||
for images, prompts, prompt_metadata in prompt_image_pairs:
|
||||
reward, reward_metadata = self.reward_fn(images, prompts, prompt_metadata)
|
||||
rewards.append(
|
||||
(
|
||||
torch.as_tensor(reward, device=self.accelerator.device),
|
||||
reward_metadata,
|
||||
)
|
||||
)
|
||||
else:
|
||||
rewards = self.executor.map(lambda x: self.reward_fn(*x), prompt_image_pairs)
|
||||
rewards = [
|
||||
(torch.as_tensor(reward.result(), device=self.accelerator.device), reward_metadata.result())
|
||||
for reward, reward_metadata in rewards
|
||||
]
|
||||
|
||||
return zip(*rewards)
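# Illustrative only: compute_rewards expects self.reward_fn(images, prompts, prompt_metadata) to
# return a (reward, reward_metadata) pair for each batch. A minimal placeholder might look like:
#
#     def my_reward_fn(images, prompts, prompt_metadata):
#         rewards = torch.rand(images.shape[0])  # one dummy scalar per image
#         return rewards, {}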
|
||||
|
||||
def step(self, epoch: int, global_step: int):
|
||||
"""
|
||||
Perform a single step of training.
|
||||
|
||||
Args:
|
||||
epoch (int): The current epoch.
|
||||
global_step (int): The current global step.
|
||||
|
||||
Side Effects:
|
||||
- Model weights are updated
|
||||
- Logs the statistics to the accelerator trackers.
|
||||
- If `self.image_samples_callback` is not None, it will be called with the prompt_image_pairs, global_step, and the accelerator tracker.
|
||||
|
||||
Returns:
|
||||
global_step (int): The updated global step.
|
||||
|
||||
"""
|
||||
samples, prompt_image_data = self._generate_samples(
|
||||
iterations=self.config.sample_num_batches_per_epoch,
|
||||
batch_size=self.config.sample_batch_size,
|
||||
)
|
||||
|
||||
# collate samples into dict where each entry has shape (num_batches_per_epoch * sample.batch_size, ...)
|
||||
samples = {k: torch.cat([s[k] for s in samples]) for k in samples[0].keys()}
|
||||
rewards, rewards_metadata = self.compute_rewards(
|
||||
prompt_image_data, is_async=self.config.async_reward_computation
|
||||
)
|
||||
|
||||
for i, image_data in enumerate(prompt_image_data):
|
||||
image_data.extend([rewards[i], rewards_metadata[i]])
|
||||
|
||||
if self.image_samples_callback is not None:
|
||||
self.image_samples_callback(prompt_image_data, global_step, self.accelerator.trackers[0])
|
||||
|
||||
rewards = torch.cat(rewards)
|
||||
rewards = self.accelerator.gather(rewards).cpu().numpy()
|
||||
|
||||
self.accelerator.log(
|
||||
{
|
||||
"reward": rewards,
|
||||
"epoch": epoch,
|
||||
"reward_mean": rewards.mean(),
|
||||
"reward_std": rewards.std(),
|
||||
},
|
||||
step=global_step,
|
||||
)
|
||||
|
||||
if self.config.per_prompt_stat_tracking:
|
||||
# gather the prompts across processes
|
||||
prompt_ids = self.accelerator.gather(samples["prompt_ids"]).cpu().numpy()
|
||||
prompts = self.sd_pipeline.tokenizer.batch_decode(prompt_ids, skip_special_tokens=True)
|
||||
advantages = self.stat_tracker.update(prompts, rewards)
|
||||
else:
|
||||
advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-8)
|
||||
|
||||
# ungather advantages; keep the entries corresponding to the samples on this process
|
||||
samples["advantages"] = (
|
||||
torch.as_tensor(advantages)
|
||||
.reshape(self.accelerator.num_processes, -1)[self.accelerator.process_index]
|
||||
.to(self.accelerator.device)
|
||||
)
|
||||
|
||||
del samples["prompt_ids"]
|
||||
|
||||
total_batch_size, num_timesteps = samples["timesteps"].shape
|
||||
|
||||
for inner_epoch in range(self.config.train_num_inner_epochs):
|
||||
# shuffle samples along batch dimension
|
||||
perm = torch.randperm(total_batch_size, device=self.accelerator.device)
|
||||
samples = {k: v[perm] for k, v in samples.items()}
|
||||
|
||||
# shuffle along time dimension independently for each sample
|
||||
# still trying to understand the code below
|
||||
perms = torch.stack(
|
||||
[torch.randperm(num_timesteps, device=self.accelerator.device) for _ in range(total_batch_size)]
|
||||
)
|
||||
|
||||
for key in ["timesteps", "latents", "next_latents", "log_probs"]:
|
||||
samples[key] = samples[key][
|
||||
torch.arange(total_batch_size, device=self.accelerator.device)[:, None],
|
||||
perms,
|
||||
]
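# Illustrative only: the advanced indexing above applies a *different* timestep permutation to each
# sample, and "timesteps", "latents", "next_latents" and "log_probs" stay aligned because they all
# share the same `perms`. A tiny example of the indexing pattern:
#
#     x     = torch.arange(6).reshape(2, 3)        # [[0, 1, 2], [3, 4, 5]]
#     perms = torch.tensor([[2, 0, 1], [1, 2, 0]])
#     x[torch.arange(2)[:, None], perms]            # [[2, 0, 1], [4, 5, 3]]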
|
||||
|
||||
original_keys = samples.keys()
|
||||
original_values = samples.values()
|
||||
# rebatch them as user defined train_batch_size is different from sample_batch_size
|
||||
reshaped_values = [v.reshape(-1, self.config.train_batch_size, *v.shape[1:]) for v in original_values]
|
||||
|
||||
# Transpose the list of original values
|
||||
transposed_values = zip(*reshaped_values)
|
||||
# Create new dictionaries for each row of transposed values
|
||||
samples_batched = [dict(zip(original_keys, row_values)) for row_values in transposed_values]
|
||||
|
||||
self.sd_pipeline.unet.train()
|
||||
global_step = self._train_batched_samples(inner_epoch, epoch, global_step, samples_batched)
|
||||
# ensure optimization step at the end of the inner epoch
|
||||
if not self.accelerator.sync_gradients:
|
||||
raise ValueError(
|
||||
"Optimization step should have been performed by this point. Please check calculated gradient accumulation settings."
|
||||
)
|
||||
|
||||
if epoch != 0 and epoch % self.config.save_freq == 0 and self.accelerator.is_main_process:
|
||||
self.accelerator.save_state()
|
||||
|
||||
return global_step
|
||||
|
||||
def calculate_loss(self, latents, timesteps, next_latents, log_probs, advantages, embeds):
|
||||
"""
|
||||
Calculate the loss for a batch of unpacked samples.
|
||||
|
||||
Args:
|
||||
latents (torch.Tensor):
|
||||
The latents sampled from the diffusion model, shape: [batch_size, num_channels_latents, height, width]
|
||||
timesteps (torch.Tensor):
|
||||
The timesteps sampled from the diffusion model, shape: [batch_size]
|
||||
next_latents (torch.Tensor):
|
||||
The next latents sampled from the diffusion model, shape: [batch_size, num_channels_latents, height, width]
|
||||
log_probs (torch.Tensor):
|
||||
The log probabilities of the latents, shape: [batch_size]
|
||||
advantages (torch.Tensor):
|
||||
The advantages of the latents, shape: [batch_size]
|
||||
embeds (torch.Tensor):
|
||||
The embeddings of the prompts, shape: [2*batch_size or batch_size, ...]
|
||||
Note: the "or" is because if train_cfg is True, the expectation is that negative prompts are concatenated to the embeds
|
||||
|
||||
Returns:
|
||||
loss (torch.Tensor), approx_kl (torch.Tensor), clipfrac (torch.Tensor)
|
||||
(all of these are of shape (1,))
|
||||
"""
|
||||
with self.autocast():
|
||||
if self.config.train_cfg:
|
||||
noise_pred = self.sd_pipeline.unet(
|
||||
torch.cat([latents] * 2),
|
||||
torch.cat([timesteps] * 2),
|
||||
embeds,
|
||||
).sample
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + self.config.sample_guidance_scale * (
|
||||
noise_pred_text - noise_pred_uncond
|
||||
)
|
||||
else:
|
||||
noise_pred = self.sd_pipeline.unet(
|
||||
latents,
|
||||
timesteps,
|
||||
embeds,
|
||||
).sample
|
||||
# compute the log prob of next_latents given latents under the current model
|
||||
|
||||
scheduler_step_output = self.sd_pipeline.scheduler_step(
|
||||
noise_pred,
|
||||
timesteps,
|
||||
latents,
|
||||
eta=self.config.sample_eta,
|
||||
prev_sample=next_latents,
|
||||
)
|
||||
|
||||
log_prob = scheduler_step_output.log_probs
|
||||
|
||||
advantages = torch.clamp(
|
||||
advantages,
|
||||
-self.config.train_adv_clip_max,
|
||||
self.config.train_adv_clip_max,
|
||||
)
|
||||
|
||||
ratio = torch.exp(log_prob - log_probs)
|
||||
|
||||
loss = self.loss(advantages, self.config.train_clip_range, ratio)
|
||||
|
||||
approx_kl = 0.5 * torch.mean((log_prob - log_probs) ** 2)
|
||||
|
||||
clipfrac = torch.mean((torch.abs(ratio - 1.0) > self.config.train_clip_range).float())
|
||||
|
||||
return loss, approx_kl, clipfrac
|
||||
|
||||
def loss(
|
||||
self,
|
||||
advantages: torch.Tensor,
|
||||
clip_range: float,
|
||||
ratio: torch.Tensor,
|
||||
):
|
||||
unclipped_loss = -advantages * ratio
|
||||
clipped_loss = -advantages * torch.clamp(
|
||||
ratio,
|
||||
1.0 - clip_range,
|
||||
1.0 + clip_range,
|
||||
)
|
||||
return torch.mean(torch.maximum(unclipped_loss, clipped_loss))
|
||||
|
||||
def _setup_optimizer(self, trainable_layers_parameters):
|
||||
if self.config.train_use_8bit_adam:
|
||||
import bitsandbytes
|
||||
|
||||
optimizer_cls = bitsandbytes.optim.AdamW8bit
|
||||
else:
|
||||
optimizer_cls = torch.optim.AdamW
|
||||
|
||||
return optimizer_cls(
|
||||
trainable_layers_parameters,
|
||||
lr=self.config.train_learning_rate,
|
||||
betas=(self.config.train_adam_beta1, self.config.train_adam_beta2),
|
||||
weight_decay=self.config.train_adam_weight_decay,
|
||||
eps=self.config.train_adam_epsilon,
|
||||
)
|
||||
|
||||
def _save_model_hook(self, models, weights, output_dir):
|
||||
self.sd_pipeline.save_checkpoint(models, weights, output_dir)
|
||||
weights.pop() # ensures that accelerate doesn't try to handle saving of the model
|
||||
|
||||
def _load_model_hook(self, models, input_dir):
|
||||
self.sd_pipeline.load_checkpoint(models, input_dir)
|
||||
models.pop() # ensures that accelerate doesn't try to handle loading of the model
|
||||
|
||||
def _generate_samples(self, iterations, batch_size):
|
||||
"""
|
||||
Generate samples from the model
|
||||
|
||||
Args:
|
||||
iterations (int): Number of iterations to generate samples for
|
||||
batch_size (int): Batch size to use for sampling
|
||||
|
||||
Returns:
|
||||
samples (list[dict[str, torch.Tensor]]), prompt_image_pairs (list[list[Any]])
|
||||
"""
|
||||
samples = []
|
||||
prompt_image_pairs = []
|
||||
self.sd_pipeline.unet.eval()
|
||||
|
||||
sample_neg_prompt_embeds = self.neg_prompt_embed.repeat(batch_size, 1, 1)
|
||||
|
||||
for _ in range(iterations):
|
||||
prompts, prompt_metadata = zip(*[self.prompt_fn() for _ in range(batch_size)])
|
||||
|
||||
prompt_ids = self.sd_pipeline.tokenizer(
|
||||
prompts,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=self.sd_pipeline.tokenizer.model_max_length,
|
||||
).input_ids.to(self.accelerator.device)
|
||||
prompt_embeds = self.sd_pipeline.text_encoder(prompt_ids)[0]
|
||||
|
||||
with self.autocast():
|
||||
sd_output = self.sd_pipeline(
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=sample_neg_prompt_embeds,
|
||||
num_inference_steps=self.config.sample_num_steps,
|
||||
guidance_scale=self.config.sample_guidance_scale,
|
||||
eta=self.config.sample_eta,
|
||||
output_type="pt",
|
||||
)
|
||||
|
||||
images = sd_output.images
|
||||
latents = sd_output.latents
|
||||
log_probs = sd_output.log_probs
|
||||
|
||||
latents = torch.stack(latents, dim=1) # (batch_size, num_steps + 1, ...)
|
||||
log_probs = torch.stack(log_probs, dim=1) # (batch_size, num_steps, 1)
|
||||
timesteps = self.sd_pipeline.scheduler.timesteps.repeat(batch_size, 1) # (batch_size, num_steps)
|
||||
|
||||
samples.append(
|
||||
{
|
||||
"prompt_ids": prompt_ids,
|
||||
"prompt_embeds": prompt_embeds,
|
||||
"timesteps": timesteps,
|
||||
"latents": latents[:, :-1], # each entry is the latent before timestep t
|
||||
"next_latents": latents[:, 1:], # each entry is the latent after timestep t
|
||||
"log_probs": log_probs,
|
||||
"negative_prompt_embeds": sample_neg_prompt_embeds,
|
||||
}
|
||||
)
|
||||
prompt_image_pairs.append([images, prompts, prompt_metadata])
|
||||
|
||||
return samples, prompt_image_pairs
|
||||
|
||||
def _train_batched_samples(self, inner_epoch, epoch, global_step, batched_samples):
|
||||
"""
|
||||
Train on a batch of samples. Main training segment
|
||||
|
||||
Args:
|
||||
inner_epoch (int): The current inner epoch
|
||||
epoch (int): The current epoch
|
||||
global_step (int): The current global step
|
||||
batched_samples (list[dict[str, torch.Tensor]]): The batched samples to train on
|
||||
|
||||
Side Effects:
|
||||
- Model weights are updated
|
||||
- Logs the statistics to the accelerator trackers.
|
||||
|
||||
Returns:
|
||||
global_step (int): The updated global step
|
||||
"""
|
||||
info = defaultdict(list)
|
||||
for _i, sample in enumerate(batched_samples):
|
||||
if self.config.train_cfg:
|
||||
# concat negative prompts to sample prompts to avoid two forward passes
|
||||
embeds = torch.cat([sample["negative_prompt_embeds"], sample["prompt_embeds"]])
|
||||
else:
|
||||
embeds = sample["prompt_embeds"]
|
||||
|
||||
for j in range(self.num_train_timesteps):
|
||||
with self.accelerator.accumulate(self.sd_pipeline.unet):
|
||||
loss, approx_kl, clipfrac = self.calculate_loss(
|
||||
sample["latents"][:, j],
|
||||
sample["timesteps"][:, j],
|
||||
sample["next_latents"][:, j],
|
||||
sample["log_probs"][:, j],
|
||||
sample["advantages"],
|
||||
embeds,
|
||||
)
|
||||
info["approx_kl"].append(approx_kl)
|
||||
info["clipfrac"].append(clipfrac)
|
||||
info["loss"].append(loss)
|
||||
|
||||
self.accelerator.backward(loss)
|
||||
if self.accelerator.sync_gradients:
|
||||
self.accelerator.clip_grad_norm_(
|
||||
self.trainable_layers.parameters()
|
||||
if not isinstance(self.trainable_layers, list)
|
||||
else self.trainable_layers,
|
||||
self.config.train_max_grad_norm,
|
||||
)
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if self.accelerator.sync_gradients:
|
||||
# log training-related stuff
|
||||
info = {k: torch.mean(torch.stack(v)) for k, v in info.items()}
|
||||
info = self.accelerator.reduce(info, reduction="mean")
|
||||
info.update({"epoch": epoch, "inner_epoch": inner_epoch})
|
||||
self.accelerator.log(info, step=global_step)
|
||||
global_step += 1
|
||||
info = defaultdict(list)
|
||||
return global_step
|
||||
|
||||
def _config_check(self) -> tuple[bool, str]:
|
||||
samples_per_epoch = (
|
||||
self.config.sample_batch_size * self.accelerator.num_processes * self.config.sample_num_batches_per_epoch
|
||||
)
|
||||
total_train_batch_size = (
|
||||
self.config.train_batch_size
|
||||
* self.accelerator.num_processes
|
||||
* self.config.train_gradient_accumulation_steps
|
||||
)
|
||||
|
||||
if not self.config.sample_batch_size >= self.config.train_batch_size:
|
||||
return (
|
||||
False,
|
||||
f"Sample batch size ({self.config.sample_batch_size}) must be greater than or equal to the train batch size ({self.config.train_batch_size})",
|
||||
)
|
||||
if not self.config.sample_batch_size % self.config.train_batch_size == 0:
|
||||
return (
|
||||
False,
|
||||
f"Sample batch size ({self.config.sample_batch_size}) must be divisible by the train batch size ({self.config.train_batch_size})",
|
||||
)
|
||||
if not samples_per_epoch % total_train_batch_size == 0:
|
||||
return (
|
||||
False,
|
||||
f"Number of samples per epoch ({samples_per_epoch}) must be divisible by the total train batch size ({total_train_batch_size})",
|
||||
)
|
||||
return True, ""
|
||||
|
||||
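# Worked example for the checks above (hypothetical numbers, not the configured defaults):
# with sample_batch_size=4, train_batch_size=2, one process and
# train_gradient_accumulation_steps=4, total_train_batch_size = 2 * 1 * 4 = 8 and
# samples_per_epoch = 4 * 1 * sample_num_batches_per_epoch, so the final divisibility
# check passes only for an even sample_num_batches_per_epoch (8, 16, ... samples per epoch).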
def train(self, epochs: Optional[int] = None):
|
||||
"""
|
||||
Train the model for a given number of epochs
|
||||
"""
|
||||
global_step = 0
|
||||
if epochs is None:
|
||||
epochs = self.config.num_epochs
|
||||
for epoch in range(self.first_epoch, epochs):
|
||||
global_step = self.step(epoch, global_step)
|
||||
|
||||
def _save_pretrained(self, save_directory):
|
||||
self.sd_pipeline.save_pretrained(save_directory)
|
||||
self.create_model_card()
|
||||
|
||||
def create_model_card(
|
||||
self,
|
||||
model_name: Optional[str] = None,
|
||||
dataset_name: Optional[str] = None,
|
||||
tags: Union[str, list[str], None] = None,
|
||||
):
|
||||
"""
|
||||
Creates a draft of a model card using the information available to the `Trainer`.
|
||||
|
||||
Args:
|
||||
model_name (`str` or `None`, *optional*, defaults to `None`):
|
||||
Name of the model.
|
||||
dataset_name (`str` or `None`, *optional*, defaults to `None`):
|
||||
Name of the dataset used for training.
|
||||
tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
|
||||
Tags to be associated with the model card.
|
||||
"""
|
||||
if not self.is_world_process_zero():
|
||||
return
|
||||
|
||||
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
|
||||
base_model = self.model.config._name_or_path
|
||||
else:
|
||||
base_model = None
|
||||
|
||||
tags = tags or []
|
||||
if isinstance(tags, str):
|
||||
tags = [tags]
|
||||
|
||||
if hasattr(self.model.config, "unsloth_version"):
|
||||
tags.append("unsloth")
|
||||
|
||||
citation = textwrap.dedent("""\
|
||||
@inproceedings{black2024training,
|
||||
title = {{Training Diffusion Models with Reinforcement Learning}},
|
||||
author = {Kevin Black and Michael Janner and Yilun Du and Ilya Kostrikov and Sergey Levine},
|
||||
year = 2024,
|
||||
booktitle = {The Twelfth International Conference on Learning Representations, {ICLR} 2024, Vienna, Austria, May 7-11, 2024},
|
||||
publisher = {OpenReview.net},
|
||||
url = {https://openreview.net/forum?id=YCWjhGrJFD},
|
||||
}""")
|
||||
|
||||
model_card = generate_model_card(
|
||||
base_model=base_model,
|
||||
model_name=model_name,
|
||||
hub_model_id=self.hub_model_id,
|
||||
dataset_name=dataset_name,
|
||||
tags=tags,
|
||||
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
|
||||
comet_url=get_comet_experiment_url(),
|
||||
trainer_name="DDPO",
|
||||
trainer_citation=citation,
|
||||
paper_title="Training Diffusion Models with Reinforcement Learning",
|
||||
paper_id="2305.13301",
|
||||
)
|
||||
|
||||
model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
||||
class UnslothDDPOTrainer(_UnslothDDPOTrainer):
|
||||
"""
|
||||
|
||||
The DDPOTrainer uses Denoising Diffusion Policy Optimization to optimise diffusion models.
|
||||
Note, this trainer is heavily inspired by the work here: https://github.com/kvablack/ddpo-pytorch
|
||||
As of now only Stable Diffusion based pipelines are supported
|
||||
|
||||
Attributes:
|
||||
**config** (`DDPOConfig`) -- Configuration object for DDPOTrainer. Check the documentation of `DDPOConfig` for more
|
||||
details.
|
||||
**reward_function** (Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor]) -- Reward function to be used
|
||||
**prompt_function** (Callable[[], tuple[str, Any]]) -- Function to generate prompts to guide model
|
||||
**sd_pipeline** (`DDPOStableDiffusionPipeline`) -- Stable Diffusion pipeline to be used for training.
|
||||
**image_samples_hook** (Optional[Callable[[Any, Any, Any], Any]]) -- Hook to be called to log images
|
||||
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
reward_function,
|
||||
prompt_function,
|
||||
sd_pipeline,
|
||||
image_samples_hook = None,
|
||||
**kwargs
|
||||
):
|
||||
if config is None: config = UnslothDDPOConfig()
|
||||
other_metrics = []
|
||||
|
||||
from unsloth_zoo.logging_utils import PatchRLStatistics
|
||||
PatchRLStatistics('ddpo_trainer', other_metrics)
|
||||
|
||||
super().__init__(
|
||||
config = config,
|
||||
reward_function = reward_function,
|
||||
prompt_function = prompt_function,
|
||||
sd_pipeline = sd_pipeline,
|
||||
image_samples_hook = image_samples_hook,**kwargs)
|
||||
|
||||
pass
|
||||
2096
unsloth_compiled_cache/UnslothDPOTrainer.py
Normal file
File diff suppressed because it is too large
832
unsloth_compiled_cache/UnslothGKDTrainer.py
Normal file
@@ -0,0 +1,832 @@
|
||||
"""
|
||||
2025.6.1
|
||||
2025.6.2
|
||||
4.52.4
|
||||
0.18.2
|
||||
__UNSLOTH_VERSIONING__
|
||||
"""
|
||||
from torch import Tensor
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
from trl.trainer.gkd_trainer import (Any, AutoModelForCausalLM, BaseImageProcessor, Callable, DataCollator, DataCollatorForChatML, Dataset, EvalPrediction, F, FeatureExtractionMixin, GKDConfig, GKDTrainer, GenerationConfig, Optional, PeftConfig, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SFTTrainer, TrainerCallback, Union, disable_dropout_in_model, empty_cache, generate_model_card, get_comet_experiment_url, is_wandb_available, nn, os, prepare_deepspeed, random, textwrap, torch, unwrap_model_for_generation)
|
||||
|
||||
|
||||
import os
|
||||
from typing import *
|
||||
from dataclasses import dataclass, field
|
||||
from packaging.version import Version
|
||||
import torch
|
||||
import numpy as np
|
||||
from contextlib import nullcontext
|
||||
from torch.nn import functional as F
|
||||
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
|
||||
|
||||
torch_compile_options = {
|
||||
"epilogue_fusion" : True,
|
||||
"max_autotune" : False,
|
||||
"shape_padding" : True,
|
||||
"trace.enabled" : False,
|
||||
"triton.cudagraphs" : False,
|
||||
}
|
||||
|
||||
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
||||
def selective_log_softmax(logits, index):
|
||||
logits = logits.to(torch.float32)
|
||||
selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
|
||||
# loop to reduce peak mem consumption
|
||||
# logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
|
||||
logsumexp_values = torch.logsumexp(logits, dim = -1)
|
||||
per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
|
||||
return per_token_logps
|
||||
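# Usage sketch (hypothetical shapes; comments only, so this compiled-cache module is
# unchanged at import time): for logits of shape (batch, seq, vocab) and chosen token
# ids `index` of shape (batch, seq),
#     per_token_logps = selective_log_softmax(logits, index)
# matches logits.log_softmax(-1).gather(-1, index.unsqueeze(-1)).squeeze(-1) in float32,
# but avoids materialising the full log-softmax tensor just to gather one column.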
@dataclass
|
||||
class UnslothGKDConfig(GKDConfig):
|
||||
"""
|
||||
|
||||
Configuration class for [`GKDTrainer`].
|
||||
|
||||
Args:
|
||||
temperature (`float`, *optional*, defaults to `0.9`):
|
||||
Temperature for sampling. The higher the temperature, the more random the completions.
|
||||
lmbda (`float`, *optional*, defaults to `0.5`):
|
||||
Lambda parameter that controls the student data fraction (i.e., the proportion of on-policy
|
||||
student-generated outputs).
|
||||
beta (`float`, *optional*, defaults to `0.5`):
|
||||
Interpolation coefficient between `0.0` and `1.0` of the Generalized Jensen-Shannon Divergence loss. When
|
||||
beta is `0.0`, the loss is the KL divergence. When beta is `1.0`, the loss is the Inverse KL Divergence.
|
||||
max_new_tokens (`int`, *optional*, defaults to `128`):
|
||||
Maximum number of tokens to generate per completion.
|
||||
teacher_model_name_or_path (`str` or `None`, *optional*, defaults to `None`):
|
||||
Model name or path of the teacher model. If `None`, the teacher model will be the same as the model
|
||||
being trained.
|
||||
teacher_model_init_kwargs (`dict[str, Any]]` or `None`, *optional*, defaults to `None`):
|
||||
Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the teacher model
|
||||
from a string.
|
||||
disable_dropout (`bool`, *optional*, defaults to `True`):
|
||||
Whether to disable dropout in the model.
|
||||
seq_kd (`bool`, *optional*, defaults to `False`):
|
||||
Seq_kd parameter that controls whether to perform Sequence-Level KD (can be viewed as supervised FT
|
||||
on teacher-generated output).
|
||||
|
||||
"""
|
||||
vllm_sampling_params: Optional[Any] = field(
|
||||
default = None,
|
||||
metadata = {'help': 'vLLM SamplingParams'},
|
||||
)
|
||||
unsloth_num_chunks : Optional[int] = field(
|
||||
default = -1,
|
||||
metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
|
||||
)
|
||||
def __init__(
|
||||
self,
|
||||
output_dir = None,
|
||||
overwrite_output_dir = None,
|
||||
do_train = False,
|
||||
do_eval = False,
|
||||
do_predict = False,
|
||||
eval_strategy = 'no',
|
||||
prediction_loss_only = False,
|
||||
per_device_train_batch_size = 4,
|
||||
per_device_eval_batch_size = 4,
|
||||
per_gpu_train_batch_size = None,
|
||||
per_gpu_eval_batch_size = None,
|
||||
gradient_accumulation_steps = 2,
|
||||
eval_accumulation_steps = 2,
|
||||
eval_delay = 0,
|
||||
torch_empty_cache_steps = 250,
|
||||
learning_rate = 5e-05,
|
||||
weight_decay = 0.01,
|
||||
adam_beta1 = 0.9,
|
||||
adam_beta2 = 0.999,
|
||||
adam_epsilon = 1e-08,
|
||||
max_grad_norm = 1.0,
|
||||
num_train_epochs = 3.0,
|
||||
max_steps = -1,
|
||||
lr_scheduler_type = 'linear',
|
||||
warmup_ratio = 0.1,
|
||||
warmup_steps = 0,
|
||||
log_level = 'passive',
|
||||
log_level_replica = 'warning',
|
||||
log_on_each_node = True,
|
||||
logging_dir = None,
|
||||
logging_strategy = 'steps',
|
||||
logging_first_step = False,
|
||||
logging_steps = 1,
|
||||
logging_nan_inf_filter = False,
|
||||
save_strategy = 'steps',
|
||||
save_steps = 500,
|
||||
save_total_limit = None,
|
||||
save_safetensors = True,
|
||||
save_on_each_node = False,
|
||||
save_only_model = False,
|
||||
restore_callback_states_from_checkpoint = False,
|
||||
no_cuda = False,
|
||||
use_cpu = False,
|
||||
use_mps_device = False,
|
||||
seed = 3407,
|
||||
data_seed = 3407,
|
||||
jit_mode_eval = False,
|
||||
use_ipex = False,
|
||||
bf16 = False,
|
||||
fp16 = False,
|
||||
fp16_opt_level = 'O1',
|
||||
half_precision_backend = 'auto',
|
||||
bf16_full_eval = False,
|
||||
fp16_full_eval = False,
|
||||
tf32 = None,
|
||||
local_rank = -1,
|
||||
ddp_backend = None,
|
||||
tpu_num_cores = None,
|
||||
tpu_metrics_debug = False,
|
||||
debug = '',
|
||||
dataloader_drop_last = False,
|
||||
eval_steps = None,
|
||||
dataloader_num_workers = 0,
|
||||
dataloader_prefetch_factor = None,
|
||||
past_index = -1,
|
||||
run_name = None,
|
||||
disable_tqdm = None,
|
||||
remove_unused_columns = True,
|
||||
label_names = None,
|
||||
load_best_model_at_end = False,
|
||||
metric_for_best_model = None,
|
||||
greater_is_better = None,
|
||||
ignore_data_skip = False,
|
||||
fsdp = '',
|
||||
fsdp_min_num_params = 0,
|
||||
fsdp_config = None,
|
||||
fsdp_transformer_layer_cls_to_wrap = None,
|
||||
accelerator_config = None,
|
||||
deepspeed = None,
|
||||
label_smoothing_factor = 0.0,
|
||||
optim = 'adamw_8bit',
|
||||
optim_args = None,
|
||||
adafactor = False,
|
||||
group_by_length = False,
|
||||
length_column_name = 'length',
|
||||
report_to = None,
|
||||
ddp_find_unused_parameters = None,
|
||||
ddp_bucket_cap_mb = None,
|
||||
ddp_broadcast_buffers = None,
|
||||
dataloader_pin_memory = True,
|
||||
dataloader_persistent_workers = False,
|
||||
skip_memory_metrics = True,
|
||||
use_legacy_prediction_loop = False,
|
||||
push_to_hub = False,
|
||||
resume_from_checkpoint = None,
|
||||
hub_model_id = None,
|
||||
hub_strategy = 'every_save',
|
||||
hub_token = None,
|
||||
hub_private_repo = None,
|
||||
hub_always_push = False,
|
||||
gradient_checkpointing = False,
|
||||
gradient_checkpointing_kwargs = None,
|
||||
include_inputs_for_metrics = False,
|
||||
eval_do_concat_batches = True,
|
||||
fp16_backend = 'auto',
|
||||
push_to_hub_model_id = None,
|
||||
push_to_hub_organization = None,
|
||||
push_to_hub_token = None,
|
||||
mp_parameters = '',
|
||||
auto_find_batch_size = False,
|
||||
full_determinism = False,
|
||||
torchdynamo = None,
|
||||
ray_scope = 'last',
|
||||
ddp_timeout = 1800,
|
||||
torch_compile = False,
|
||||
torch_compile_backend = None,
|
||||
torch_compile_mode = None,
|
||||
include_tokens_per_second = False,
|
||||
include_num_input_tokens_seen = False,
|
||||
neftune_noise_alpha = None,
|
||||
optim_target_modules = None,
|
||||
batch_eval_metrics = False,
|
||||
eval_on_start = False,
|
||||
use_liger_kernel = False,
|
||||
eval_use_gather_object = False,
|
||||
average_tokens_across_devices = False,
|
||||
model_init_kwargs = None,
|
||||
dataset_text_field = 'text',
|
||||
dataset_kwargs = None,
|
||||
dataset_num_proc = None,
|
||||
eos_token = None,
|
||||
pad_token = None,
|
||||
max_length = 1024,
|
||||
packing = False,
|
||||
padding_free = False,
|
||||
pad_to_multiple_of = None,
|
||||
eval_packing = None,
|
||||
completion_only_loss = None,
|
||||
activation_offloading = False,
|
||||
max_seq_length = None,
|
||||
temperature = 0.9,
|
||||
lmbda = 0.5,
|
||||
beta = 0.5,
|
||||
max_new_tokens = 128,
|
||||
teacher_model_name_or_path = None,
|
||||
teacher_model_init_kwargs = None,
|
||||
disable_dropout = True,
|
||||
seq_kd = False,
|
||||
vllm_sampling_params = None,
|
||||
unsloth_num_chunks = -1,
|
||||
**kwargs,
|
||||
):
|
||||
if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small (less than 1e-7)! Consider increasing it, otherwise gradient updates will be close to 0!')
if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too large (> 1)! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
|
||||
if output_dir is None and save_strategy == 'steps' and save_steps == 500:
|
||||
output_dir = 'unsloth_training_checkpoints'
|
||||
save_strategy = 'no'
|
||||
if dataset_num_proc is None:
|
||||
from multiprocessing import cpu_count
|
||||
dataset_num_proc = cpu_count()
|
||||
|
||||
super().__init__(
|
||||
output_dir = output_dir,
|
||||
overwrite_output_dir = overwrite_output_dir,
|
||||
do_train = do_train,
|
||||
do_eval = do_eval,
|
||||
do_predict = do_predict,
|
||||
eval_strategy = eval_strategy,
|
||||
prediction_loss_only = prediction_loss_only,
|
||||
per_device_train_batch_size = per_device_train_batch_size,
|
||||
per_device_eval_batch_size = per_device_eval_batch_size,
|
||||
per_gpu_train_batch_size = per_gpu_train_batch_size,
|
||||
per_gpu_eval_batch_size = per_gpu_eval_batch_size,
|
||||
gradient_accumulation_steps = gradient_accumulation_steps,
|
||||
eval_accumulation_steps = eval_accumulation_steps,
|
||||
eval_delay = eval_delay,
|
||||
torch_empty_cache_steps = torch_empty_cache_steps,
|
||||
learning_rate = learning_rate,
|
||||
weight_decay = weight_decay,
|
||||
adam_beta1 = adam_beta1,
|
||||
adam_beta2 = adam_beta2,
|
||||
adam_epsilon = adam_epsilon,
|
||||
max_grad_norm = max_grad_norm,
|
||||
num_train_epochs = num_train_epochs,
|
||||
max_steps = max_steps,
|
||||
lr_scheduler_type = lr_scheduler_type,
|
||||
warmup_ratio = warmup_ratio,
|
||||
warmup_steps = warmup_steps,
|
||||
log_level = log_level,
|
||||
log_level_replica = log_level_replica,
|
||||
log_on_each_node = log_on_each_node,
|
||||
logging_dir = logging_dir,
|
||||
logging_strategy = logging_strategy,
|
||||
logging_first_step = logging_first_step,
|
||||
logging_steps = logging_steps,
|
||||
logging_nan_inf_filter = logging_nan_inf_filter,
|
||||
save_strategy = save_strategy,
|
||||
save_steps = save_steps,
|
||||
save_total_limit = save_total_limit,
|
||||
save_safetensors = save_safetensors,
|
||||
save_on_each_node = save_on_each_node,
|
||||
save_only_model = save_only_model,
|
||||
restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
|
||||
no_cuda = no_cuda,
|
||||
use_cpu = use_cpu,
|
||||
use_mps_device = use_mps_device,
|
||||
seed = seed,
|
||||
data_seed = data_seed,
|
||||
jit_mode_eval = jit_mode_eval,
|
||||
use_ipex = use_ipex,
|
||||
bf16 = bf16,
|
||||
fp16 = fp16,
|
||||
fp16_opt_level = fp16_opt_level,
|
||||
half_precision_backend = half_precision_backend,
|
||||
bf16_full_eval = bf16_full_eval,
|
||||
fp16_full_eval = fp16_full_eval,
|
||||
tf32 = tf32,
|
||||
local_rank = local_rank,
|
||||
ddp_backend = ddp_backend,
|
||||
tpu_num_cores = tpu_num_cores,
|
||||
tpu_metrics_debug = tpu_metrics_debug,
|
||||
debug = debug,
|
||||
dataloader_drop_last = dataloader_drop_last,
|
||||
eval_steps = eval_steps,
|
||||
dataloader_num_workers = dataloader_num_workers,
|
||||
dataloader_prefetch_factor = dataloader_prefetch_factor,
|
||||
past_index = past_index,
|
||||
run_name = run_name,
|
||||
disable_tqdm = disable_tqdm,
|
||||
remove_unused_columns = remove_unused_columns,
|
||||
label_names = label_names,
|
||||
load_best_model_at_end = load_best_model_at_end,
|
||||
metric_for_best_model = metric_for_best_model,
|
||||
greater_is_better = greater_is_better,
|
||||
ignore_data_skip = ignore_data_skip,
|
||||
fsdp = fsdp,
|
||||
fsdp_min_num_params = fsdp_min_num_params,
|
||||
fsdp_config = fsdp_config,
|
||||
fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
|
||||
accelerator_config = accelerator_config,
|
||||
deepspeed = deepspeed,
|
||||
label_smoothing_factor = label_smoothing_factor,
|
||||
optim = optim,
|
||||
optim_args = optim_args,
|
||||
adafactor = adafactor,
|
||||
group_by_length = group_by_length,
|
||||
length_column_name = length_column_name,
|
||||
report_to = report_to,
|
||||
ddp_find_unused_parameters = ddp_find_unused_parameters,
|
||||
ddp_bucket_cap_mb = ddp_bucket_cap_mb,
|
||||
ddp_broadcast_buffers = ddp_broadcast_buffers,
|
||||
dataloader_pin_memory = dataloader_pin_memory,
|
||||
dataloader_persistent_workers = dataloader_persistent_workers,
|
||||
skip_memory_metrics = skip_memory_metrics,
|
||||
use_legacy_prediction_loop = use_legacy_prediction_loop,
|
||||
push_to_hub = push_to_hub,
|
||||
resume_from_checkpoint = resume_from_checkpoint,
|
||||
hub_model_id = hub_model_id,
|
||||
hub_strategy = hub_strategy,
|
||||
hub_token = hub_token,
|
||||
hub_private_repo = hub_private_repo,
|
||||
hub_always_push = hub_always_push,
|
||||
gradient_checkpointing = gradient_checkpointing,
|
||||
gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
|
||||
include_inputs_for_metrics = include_inputs_for_metrics,
|
||||
eval_do_concat_batches = eval_do_concat_batches,
|
||||
fp16_backend = fp16_backend,
|
||||
push_to_hub_model_id = push_to_hub_model_id,
|
||||
push_to_hub_organization = push_to_hub_organization,
|
||||
push_to_hub_token = push_to_hub_token,
|
||||
mp_parameters = mp_parameters,
|
||||
auto_find_batch_size = auto_find_batch_size,
|
||||
full_determinism = full_determinism,
|
||||
torchdynamo = torchdynamo,
|
||||
ray_scope = ray_scope,
|
||||
ddp_timeout = ddp_timeout,
|
||||
torch_compile = torch_compile,
|
||||
torch_compile_backend = torch_compile_backend,
|
||||
torch_compile_mode = torch_compile_mode,
|
||||
include_tokens_per_second = include_tokens_per_second,
|
||||
include_num_input_tokens_seen = include_num_input_tokens_seen,
|
||||
neftune_noise_alpha = neftune_noise_alpha,
|
||||
optim_target_modules = optim_target_modules,
|
||||
batch_eval_metrics = batch_eval_metrics,
|
||||
eval_on_start = eval_on_start,
|
||||
use_liger_kernel = use_liger_kernel,
|
||||
eval_use_gather_object = eval_use_gather_object,
|
||||
average_tokens_across_devices = average_tokens_across_devices,
|
||||
model_init_kwargs = model_init_kwargs,
|
||||
dataset_text_field = dataset_text_field,
|
||||
dataset_kwargs = dataset_kwargs,
|
||||
dataset_num_proc = dataset_num_proc,
|
||||
eos_token = eos_token,
|
||||
pad_token = pad_token,
|
||||
max_length = max_length,
|
||||
packing = packing,
|
||||
padding_free = padding_free,
|
||||
pad_to_multiple_of = pad_to_multiple_of,
|
||||
eval_packing = eval_packing,
|
||||
completion_only_loss = completion_only_loss,
|
||||
activation_offloading = activation_offloading,
|
||||
max_seq_length = max_seq_length,
|
||||
temperature = temperature,
|
||||
lmbda = lmbda,
|
||||
beta = beta,
|
||||
max_new_tokens = max_new_tokens,
|
||||
teacher_model_name_or_path = teacher_model_name_or_path,
|
||||
teacher_model_init_kwargs = teacher_model_init_kwargs,
|
||||
disable_dropout = disable_dropout,
|
||||
seq_kd = seq_kd,**kwargs)
|
||||
self.vllm_sampling_params = vllm_sampling_params
|
||||
self.unsloth_num_chunks = unsloth_num_chunks
|
||||
pass
|
||||
|
||||
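# Minimal construction sketch for the config above (illustrative values, not the defaults
# baked into __init__):
#     config = UnslothGKDConfig(
#         output_dir = "gkd_outputs",                      # hypothetical path
#         teacher_model_name_or_path = "some/teacher-id",  # hypothetical model id
#         lmbda = 0.5,        # fraction of steps trained on student-generated outputs
#         beta = 0.5,         # generalized JSD interpolation (0 -> KL, 1 -> inverse KL)
#         max_new_tokens = 128,
#     )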
class _UnslothGKDTrainer(SFTTrainer):
|
||||
_tag_names = ["trl", "gkd"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
|
||||
teacher_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
|
||||
args: Optional[GKDConfig] = None,
|
||||
data_collator: Optional[DataCollator] = None, # type: ignore
|
||||
train_dataset: Optional[Dataset] = None,
|
||||
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
|
||||
processing_class: Optional[
|
||||
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
|
||||
] = None,
|
||||
compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
|
||||
callbacks: Optional[list[TrainerCallback]] = None,
|
||||
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
|
||||
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
||||
peft_config: Optional["PeftConfig"] = None,
|
||||
formatting_func: Optional[Callable] = None,
|
||||
):
|
||||
# add remove_unused_columns=False to the dataclass args
|
||||
args.remove_unused_columns = False
|
||||
data_collator = DataCollatorForChatML(tokenizer=processing_class, max_length=args.max_length)
|
||||
|
||||
super().__init__(
|
||||
model,
|
||||
args=args,
|
||||
data_collator=data_collator,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
processing_class=processing_class,
|
||||
compute_metrics=compute_metrics,
|
||||
callbacks=callbacks,
|
||||
optimizers=optimizers,
|
||||
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
|
||||
peft_config=peft_config,
|
||||
formatting_func=formatting_func,
|
||||
)
|
||||
|
||||
if args.teacher_model_init_kwargs is None:
|
||||
teacher_model_init_kwargs = {}
|
||||
elif not isinstance(teacher_model, str):
|
||||
raise ValueError(
|
||||
"You passed teacher_model_init_kwargs to the GKDConfig, but your teacher_model is already instantiated."
|
||||
)
|
||||
else:
|
||||
teacher_model_init_kwargs = args.teacher_model_init_kwargs
|
||||
teacher_model_init_kwargs["torch_dtype"] = (
|
||||
teacher_model_init_kwargs["torch_dtype"]
|
||||
if teacher_model_init_kwargs["torch_dtype"] in ["auto", None]
|
||||
else getattr(torch, teacher_model_init_kwargs["torch_dtype"])
|
||||
)
|
||||
|
||||
if isinstance(teacher_model, str):
|
||||
teacher_model = AutoModelForCausalLM.from_pretrained(teacher_model, **teacher_model_init_kwargs)
|
||||
|
||||
# Disable dropout in the model
|
||||
if args.disable_dropout:
|
||||
disable_dropout_in_model(self.model)
|
||||
|
||||
if self.is_deepspeed_enabled:
|
||||
self.teacher_model = prepare_deepspeed(teacher_model, self.accelerator)
|
||||
else:
|
||||
self.teacher_model = self.accelerator.prepare_model(teacher_model, evaluation_mode=True)
|
||||
|
||||
self.lmbda = args.lmbda
|
||||
self.beta = args.beta
|
||||
self.temperature = args.temperature
|
||||
self.seq_kd = args.seq_kd
|
||||
|
||||
self.generation_config = GenerationConfig(
|
||||
max_new_tokens=args.max_new_tokens,
|
||||
temperature=args.temperature,
|
||||
do_sample=True,
|
||||
top_k=0,
|
||||
use_cache=False if args.gradient_checkpointing else True,
|
||||
pad_token_id=self.processing_class.pad_token_id,
|
||||
)
|
||||
# Set custom EOS tokens if they are specified by the model's generation
|
||||
# config. This is important for models with the Llama 3 chat template,
|
||||
# which use special tokens <|eot_id|> and <|eom_id|> to mark the end of
|
||||
# turns or messages.
|
||||
if (
|
||||
hasattr(self.model.generation_config, "eos_token_id")
|
||||
and self.model.generation_config.eos_token_id is not None
|
||||
):
|
||||
self.generation_config.eos_token_id = self.model.generation_config.eos_token_id
|
||||
|
||||
def _prepare_dataset(self, dataset, *args):
|
||||
# SFTTrainer._prepare_dataset() applies the chat template and renames the messages column to text. However, we
|
||||
# need to keep the messages column as it is. We use the following workaround to keep the messages column.
|
||||
dataset = dataset.add_column("_messages", dataset["messages"])
|
||||
dataset = super()._prepare_dataset(dataset, *args)
|
||||
dataset = dataset.rename_column("_messages", "messages")
|
||||
return dataset
|
||||
|
||||
@staticmethod
|
||||
def generalized_jsd_loss(
|
||||
student_logits, teacher_logits, labels=None, beta=0.5, temperature=1.0, reduction="batchmean"
|
||||
):
|
||||
"""
|
||||
Compute the generalized Jensen-Shannon Divergence loss for knowledge distillation using F.kl_div. See Eq. (1)
|
||||
of https://huggingface.co/papers/2306.13649 for the definition.
|
||||
|
||||
Args:
|
||||
student_logits: Tensor of shape (batch_size, sequence_length, vocab_size)
|
||||
teacher_logits: Tensor of shape (batch_size, sequence_length, vocab_size)
|
||||
labels: Tensor of shape (batch_size, sequence_length) with -100 for padding tokens to ignore when computing loss
|
||||
beta: Interpolation coefficient between 0 and 1 (default: 0.5)
|
||||
temperature: Softmax temperature (default: 1.0)
|
||||
reduction: Specifies the reduction to apply to the output (default: 'batchmean')
|
||||
|
||||
Returns:
|
||||
loss: Scalar tensor with the generalized JSD loss
|
||||
"""
|
||||
|
||||
# Apply temperature scaling
|
||||
student_logits = student_logits / temperature
|
||||
teacher_logits = teacher_logits / temperature
|
||||
|
||||
# Compute log probabilities for student and probabilities for teacher
|
||||
student_log_probs = F.log_softmax(student_logits, dim=-1)
|
||||
teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)
|
||||
|
||||
if beta == 0:
|
||||
jsd = F.kl_div(student_log_probs, teacher_log_probs, reduction="none", log_target=True)
|
||||
elif beta == 1:
|
||||
jsd = F.kl_div(teacher_log_probs, student_log_probs, reduction="none", log_target=True)
|
||||
else:
|
||||
# Compute the log of the mixture distribution
|
||||
# log(a + b) = log(exp(log(a)) + exp(log(b))) -> for mixture
|
||||
beta = torch.tensor(beta, dtype=student_log_probs.dtype)
|
||||
mixture_log_probs = torch.logsumexp(
|
||||
torch.stack([student_log_probs + torch.log(1 - beta), teacher_log_probs + torch.log(beta)]),
|
||||
dim=0,
|
||||
)
|
||||
|
||||
# Compute KL divergences using F.kl_div
|
||||
# PyTorch differs from the standard mathematical definition, so the order of the probability distributions is swapped compared to that defined in the paper.
|
||||
kl_teacher = F.kl_div(mixture_log_probs, teacher_log_probs, reduction="none", log_target=True)
|
||||
kl_student = F.kl_div(mixture_log_probs, student_log_probs, reduction="none", log_target=True)
|
||||
|
||||
# Compute the Generalized Jensen-Shannon Divergence
|
||||
jsd = beta * kl_teacher + (1 - beta) * kl_student
|
||||
|
||||
# Masking
|
||||
if labels is not None:
|
||||
mask = labels != -100
|
||||
jsd = jsd[mask]
|
||||
|
||||
# Apply reduction
|
||||
if reduction == "batchmean":
|
||||
return jsd.sum() / mask.sum() if labels is not None else jsd.sum() / (jsd.size(0) * jsd.size(1))
|
||||
elif reduction == "sum":
|
||||
return jsd.sum()
|
||||
elif reduction == "mean":
|
||||
return jsd.mean()
|
||||
else:
|
||||
return jsd
|
||||
|
||||
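# Sketch of the interpolation computed above (same notation as the docstring, no new API):
# with the mixture m = (1 - beta) * p_student + beta * p_teacher formed in log-space via
# logsumexp, the general branch returns
#     beta * KL(p_teacher || m) + (1 - beta) * KL(p_student || m)
# (before masking/reduction); the beta == 0 and beta == 1 branches are the corresponding
# limits, computed with a single F.kl_div call instead of building the mixture.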
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
|
||||
# compute student output
|
||||
outputs_student = model(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
)
|
||||
|
||||
# compute teacher output in eval mode
|
||||
self.teacher_model.eval()
|
||||
with torch.no_grad():
|
||||
outputs_teacher = self.teacher_model(
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
)
|
||||
|
||||
# slice the logits for the generated tokens using the inputs["prompts"] lengths
|
||||
prompt_lengths = inputs["prompts"].shape[1]
|
||||
shifted_student_logits = outputs_student.logits[:, prompt_lengths - 1 : -1, :]
|
||||
shifted_teacher_logits = outputs_teacher.logits[:, prompt_lengths - 1 : -1, :]
|
||||
shifted_labels = inputs["labels"][:, prompt_lengths:]
|
||||
|
||||
# compute loss
|
||||
loss = self.generalized_jsd_loss(
|
||||
student_logits=shifted_student_logits,
|
||||
teacher_logits=shifted_teacher_logits,
|
||||
labels=shifted_labels,
|
||||
beta=self.beta,
|
||||
)
|
||||
|
||||
# empty cache
|
||||
empty_cache()
|
||||
|
||||
# Return loss
|
||||
return (loss, outputs_student) if return_outputs else loss
|
||||
|
||||
@staticmethod
|
||||
def generate_on_policy_outputs(model, inputs, generation_config, pad_token_id=None):
|
||||
# Generate output with respect to the prompt only
|
||||
generated_outputs = model.generate(
|
||||
input_ids=inputs["prompts"],
|
||||
attention_mask=inputs.get("prompt_attention_mask", None),
|
||||
generation_config=generation_config,
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
|
||||
# Get the generated token IDs
|
||||
generated_tokens = generated_outputs.sequences
|
||||
# Calculate new attention mask
|
||||
new_attention_mask = torch.ones_like(generated_tokens)
|
||||
new_labels = generated_tokens.clone()
|
||||
|
||||
# If there's pad_token_id, set attention mask to 0 for padding tokens
|
||||
if pad_token_id is not None:
|
||||
new_labels[new_labels == pad_token_id] = -100
|
||||
new_attention_mask[generated_tokens == pad_token_id] = 0
|
||||
|
||||
return generated_tokens, new_attention_mask, new_labels
|
||||
|
||||
def training_step(
|
||||
self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Perform a training step for the Generalized Knowledge Distillation (GKD) model.
|
||||
|
||||
This method implements the on-policy learning approach described in the GKD paper.
|
||||
With probability `self.lmbda`, it generates new responses using the student model,
|
||||
which are then used for training instead of the original inputs.
|
||||
"""
|
||||
if self.seq_kd:
|
||||
with unwrap_model_for_generation(self.teacher_model, self.accelerator) as unwrapped_model:
|
||||
new_input_ids, new_attention_mask, new_labels = self.generate_on_policy_outputs(
|
||||
unwrapped_model, inputs, self.generation_config, self.processing_class.pad_token_id
|
||||
)
|
||||
inputs["input_ids"] = new_input_ids
|
||||
inputs["attention_mask"] = new_attention_mask
|
||||
inputs["labels"] = new_labels
|
||||
if random.random() <= self.lmbda:
|
||||
with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
|
||||
new_input_ids, new_attention_mask, new_labels = self.generate_on_policy_outputs(
|
||||
unwrapped_model, inputs, self.generation_config, self.processing_class.pad_token_id
|
||||
)
|
||||
inputs["input_ids"] = new_input_ids
|
||||
inputs["attention_mask"] = new_attention_mask
|
||||
inputs["labels"] = new_labels
|
||||
|
||||
loss = super().training_step(model, inputs, num_items_in_batch)
|
||||
return loss
|
||||
|
||||
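# On-policy mixing recap (restating the docstring, no new behaviour): when seq_kd is set,
# the batch is first rewritten with teacher generations; then, with probability lmbda,
# input_ids / attention_mask / labels are replaced by student generations from
# generate_on_policy_outputs, otherwise the original sequences are kept, before the
# parent training_step runs on the (possibly rewritten) batch.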
def create_model_card(
|
||||
self,
|
||||
model_name: Optional[str] = None,
|
||||
dataset_name: Optional[str] = None,
|
||||
tags: Union[str, list[str], None] = None,
|
||||
):
|
||||
"""
|
||||
Creates a draft of a model card using the information available to the `Trainer`.
|
||||
|
||||
Args:
|
||||
model_name (`str` or `None`, *optional*, defaults to `None`):
|
||||
Name of the model.
|
||||
dataset_name (`str` or `None`, *optional*, defaults to `None`):
|
||||
Name of the dataset used for training.
|
||||
tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
|
||||
Tags to be associated with the model card.
|
||||
"""
|
||||
if not self.is_world_process_zero():
|
||||
return
|
||||
|
||||
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
|
||||
base_model = self.model.config._name_or_path
|
||||
else:
|
||||
base_model = None
|
||||
|
||||
tags = tags or []
|
||||
if isinstance(tags, str):
|
||||
tags = [tags]
|
||||
|
||||
if hasattr(self.model.config, "unsloth_version"):
|
||||
tags.append("unsloth")
|
||||
|
||||
citation = textwrap.dedent("""\
|
||||
@inproceedings{agarwal2024on-policy,
|
||||
title = {{On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes}},
|
||||
author = {Rishabh Agarwal and Nino Vieillard and Yongchao Zhou and Piotr Stanczyk and Sabela Ramos Garea and Matthieu Geist and Olivier Bachem},
|
||||
year = 2024,
|
||||
booktitle = {The Twelfth International Conference on Learning Representations, {ICLR} 2024, Vienna, Austria, May 7-11, 2024},
|
||||
publisher = {OpenReview.net},
|
||||
url = {https://openreview.net/forum?id=3zKtaqxLhW},
|
||||
}""")
|
||||
|
||||
model_card = generate_model_card(
|
||||
base_model=base_model,
|
||||
model_name=model_name,
|
||||
hub_model_id=self.hub_model_id,
|
||||
dataset_name=dataset_name,
|
||||
tags=tags,
|
||||
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
|
||||
comet_url=get_comet_experiment_url(),
|
||||
trainer_name="GKD",
|
||||
trainer_citation=citation,
|
||||
paper_title="On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes",
|
||||
paper_id="2306.13649",
|
||||
)
|
||||
|
||||
model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
||||
class UnslothGKDTrainer(_UnslothGKDTrainer):
|
||||
"""
|
||||
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
model = None,
|
||||
teacher_model = None,
|
||||
args = None,
|
||||
data_collator = None,
|
||||
train_dataset = None,
|
||||
eval_dataset = None,
|
||||
processing_class = None,
|
||||
compute_metrics = None,
|
||||
callbacks = None,
|
||||
preprocess_logits_for_metrics = None,
|
||||
peft_config = None,
|
||||
formatting_func = None,
|
||||
**kwargs
|
||||
):
|
||||
if args is None: args = UnslothGKDConfig()
|
||||
use_bf16 = getattr(args, 'bf16', False)
|
||||
use_fp16 = getattr(args, 'fp16', False)
|
||||
force_float32 = False
|
||||
if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
|
||||
print('Unsloth: Switching to float32 training since model cannot work with float16')
|
||||
force_float32 = True
|
||||
mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
|
||||
dtype = getattr(model.config, 'torch_dtype', None)
|
||||
if dtype is None: dtype = model.get_input_embeddings().dtype
|
||||
from unsloth_zoo.utils import _get_dtype
|
||||
dtype = _get_dtype(dtype)
|
||||
float16 = dtype == torch.float16
|
||||
if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
|
||||
if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
|
||||
if force_float32:
|
||||
args.fp16 = False
|
||||
args.bf16 = False
|
||||
os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
|
||||
elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
|
||||
args.fp16 = float16
|
||||
args.bf16 = not float16
|
||||
os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
|
||||
if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
|
||||
args.eval_strategy = 'steps'
|
||||
if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
|
||||
ga_steps = getattr(args, 'gradient_accumulation_steps', None)
|
||||
if ga_steps is not None and ga_steps > 1:
|
||||
from transformers import __version__ as transformers_version
|
||||
if Version(transformers_version) <= Version('4.45.2'):
|
||||
print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
|
||||
'`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
|
||||
if getattr(args, 'eval_strategy', 'no') != 'no':
|
||||
eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
|
||||
if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
|
||||
if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
|
||||
fp16_full_eval = getattr(args, 'fp16_full_eval', False)
|
||||
bf16_full_eval = getattr(args, 'bf16_full_eval', False)
|
||||
if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
|
||||
if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
|
||||
if force_float32:
|
||||
args.bf16_full_eval = False
|
||||
args.fp16_full_eval = False
|
||||
elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
|
||||
args.bf16_full_eval = True
|
||||
args.fp16_full_eval = False
|
||||
elif not bf16_full_eval and not fp16_full_eval:
|
||||
args.bf16_full_eval = args.bf16
|
||||
args.fp16_full_eval = args.fp16
|
||||
_output_logits = False
|
||||
if locals().get('compute_metrics', None) is not None: _output_logits = True
|
||||
if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
|
||||
if _output_logits:
|
||||
os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
|
||||
if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
|
||||
pass
|
||||
else:
|
||||
model_max_seq_length = getattr(model, 'max_seq_length', None)
|
||||
args_max_seq_length = getattr(args, 'max_seq_length', None)
|
||||
if args_max_seq_length is None and model_max_seq_length is not None:
|
||||
max_seq_length = model.max_seq_length
|
||||
if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
|
||||
if model is not None and hasattr(model, 'for_training'):
|
||||
model.for_training()
|
||||
if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
|
||||
if 'processing_class' in locals():
|
||||
if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
|
||||
if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
|
||||
__tokenizer = processing_class if 'processing_class' in locals() else tokenizer
|
||||
from unsloth_zoo.vision_utils import UnslothVisionDataCollator
|
||||
if not isinstance(data_collator, UnslothVisionDataCollator):
|
||||
if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
|
||||
data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer, mlm = False, mlm_probability = 0.0)
|
||||
elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
|
||||
data_collator = DataCollatorForSeq2Seq(__tokenizer)
|
||||
else:
|
||||
if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
|
||||
if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
|
||||
if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
|
||||
if not isinstance(data_collator, UnslothVisionDataCollator):
|
||||
if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
|
||||
if isinstance(data_collator, DataCollatorForSeq2Seq):
|
||||
data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
|
||||
else:
|
||||
data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False, mlm_probability = 0.0)
|
||||
other_metrics = []
|
||||
|
||||
from unsloth_zoo.logging_utils import PatchRLStatistics
|
||||
PatchRLStatistics('gkd_trainer', other_metrics)
|
||||
|
||||
super().__init__(
|
||||
model = model,
|
||||
teacher_model = teacher_model,
|
||||
args = args,
|
||||
data_collator = data_collator,
|
||||
train_dataset = train_dataset,
|
||||
eval_dataset = eval_dataset,
|
||||
processing_class = processing_class,
|
||||
compute_metrics = compute_metrics,
|
||||
callbacks = callbacks,
|
||||
preprocess_logits_for_metrics = preprocess_logits_for_metrics,
|
||||
peft_config = peft_config,
|
||||
formatting_func = formatting_func,**kwargs)
|
||||
if hasattr(self, 'neftune_hook_handle'):
|
||||
self.neftune_hook_handle.remove()
|
||||
if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
|
||||
if getattr(args, 'neftune_noise_alpha', None) is not None:
|
||||
model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
|
||||
pass
|
||||
|
||||
pass
|
||||
2302
unsloth_compiled_cache/UnslothGRPOTrainer.py
Normal file
File diff suppressed because it is too large
912
unsloth_compiled_cache/UnslothIterativeSFTTrainer.py
Normal file
@@ -0,0 +1,912 @@
|
||||
"""
|
||||
2025.6.1
|
||||
2025.6.2
|
||||
4.52.4
|
||||
0.18.2
|
||||
__UNSLOTH_VERSIONING__
|
||||
"""
|
||||
from torch import Tensor
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
from trl.trainer.iterative_sft_trainer import (AutoModelForCausalLM, AutoTokenizer, BaseImageProcessor, Callable, DataCollator, DataCollatorForLanguageModeling, DataCollatorForSeq2Seq, DataLoader, Dataset, EvalLoopOutput, FeatureExtractionMixin, IterativeSFTConfig, IterativeSFTTrainer, Optional, PPODecorators, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, Trainer, TrainingArguments, Union, generate_model_card, get_comet_experiment_url, is_peft_available, is_wandb_available, os, torch, warnings)
|
||||
|
||||
|
||||
import os
|
||||
from typing import *
|
||||
from dataclasses import dataclass, field
|
||||
from packaging.version import Version
|
||||
import torch
|
||||
import numpy as np
|
||||
from contextlib import nullcontext
|
||||
from torch.nn import functional as F
|
||||
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
|
||||
|
||||
torch_compile_options = {
|
||||
"epilogue_fusion" : True,
|
||||
"max_autotune" : False,
|
||||
"shape_padding" : True,
|
||||
"trace.enabled" : False,
|
||||
"triton.cudagraphs" : False,
|
||||
}
|
||||
|
||||
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
||||
def selective_log_softmax(logits, index):
|
||||
logits = logits.to(torch.float32)
|
||||
selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
|
||||
# loop to reduce peak mem consumption
|
||||
# logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
|
||||
logsumexp_values = torch.logsumexp(logits, dim = -1)
|
||||
per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
|
||||
return per_token_logps
|
||||
@dataclass
|
||||
class UnslothIterativeSFTConfig(IterativeSFTConfig):
|
||||
"""
|
||||
|
||||
Configuration class for the [`IterativeSFTTrainer`].
|
||||
|
||||
Only the parameters specific to iterative SFT training are listed here. For details on other parameters, refer to the
|
||||
[`~transformers.TrainingArguments`] documentation.
|
||||
|
||||
Using [`~transformers.HfArgumentParser`] we can turn this class into
|
||||
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
|
||||
command line.
|
||||
|
||||
Parameters:
|
||||
> Parameters that control the model
|
||||
|
||||
model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
|
||||
Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model`
|
||||
argument of the [`IterativeSFTTrainer`] is provided as a string.
|
||||
|
||||
> Parameters that control the data preprocessing
|
||||
|
||||
max_length (`int` or `None`, *optional*, defaults to `None`):
|
||||
Maximum length of the tokenized sequence. Sequences longer than `max_length` are truncated.
|
||||
truncation_mode (`str`, *optional*, defaults to `"keep_end"`):
|
||||
The truncation mode to use, either `"keep_end"` or `"keep_start"`.
|
||||
optimize_device_cache (`bool`, *optional*, defaults to `False`):
|
||||
Whether to optimize CUDA cache for slightly more memory-efficient training.
|
||||
|
||||
"""
|
||||
vllm_sampling_params: Optional[Any] = field(
|
||||
default = None,
|
||||
metadata = {'help': 'vLLM SamplingParams'},
|
||||
)
|
||||
unsloth_num_chunks : Optional[int] = field(
|
||||
default = -1,
|
||||
metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
|
||||
)
|
||||
def __init__(
|
||||
self,
|
||||
output_dir = None,
|
||||
overwrite_output_dir = None,
|
||||
do_train = False,
|
||||
do_eval = False,
|
||||
do_predict = False,
|
||||
eval_strategy = 'no',
|
||||
prediction_loss_only = False,
|
||||
per_device_train_batch_size = 4,
|
||||
per_device_eval_batch_size = 4,
|
||||
per_gpu_train_batch_size = None,
|
||||
per_gpu_eval_batch_size = None,
|
||||
gradient_accumulation_steps = 2,
|
||||
eval_accumulation_steps = 2,
|
||||
eval_delay = 0,
|
||||
torch_empty_cache_steps = 250,
|
||||
learning_rate = 5e-05,
|
||||
weight_decay = 0.01,
|
||||
adam_beta1 = 0.9,
|
||||
adam_beta2 = 0.999,
|
||||
adam_epsilon = 1e-08,
|
||||
max_grad_norm = 1.0,
|
||||
num_train_epochs = 3.0,
|
||||
max_steps = -1,
|
||||
lr_scheduler_type = 'linear',
|
||||
warmup_ratio = 0.1,
|
||||
warmup_steps = 0,
|
||||
log_level = 'passive',
|
||||
log_level_replica = 'warning',
|
||||
log_on_each_node = True,
|
||||
logging_dir = None,
|
||||
logging_strategy = 'steps',
|
||||
logging_first_step = False,
|
||||
logging_steps = 1,
|
||||
logging_nan_inf_filter = False,
|
||||
save_strategy = 'steps',
|
||||
save_steps = 500,
|
||||
save_total_limit = None,
|
||||
save_safetensors = True,
|
||||
save_on_each_node = False,
|
||||
save_only_model = False,
|
||||
restore_callback_states_from_checkpoint = False,
|
||||
no_cuda = False,
|
||||
use_cpu = False,
|
||||
use_mps_device = False,
|
||||
seed = 3407,
|
||||
data_seed = 3407,
|
||||
jit_mode_eval = False,
|
||||
use_ipex = False,
|
||||
bf16 = False,
|
||||
fp16 = False,
|
||||
fp16_opt_level = 'O1',
|
||||
half_precision_backend = 'auto',
|
||||
bf16_full_eval = False,
|
||||
fp16_full_eval = False,
|
||||
tf32 = None,
|
||||
local_rank = -1,
|
||||
ddp_backend = None,
|
||||
tpu_num_cores = None,
|
||||
tpu_metrics_debug = False,
|
||||
debug = '',
|
||||
dataloader_drop_last = False,
|
||||
eval_steps = None,
|
||||
dataloader_num_workers = 0,
|
||||
dataloader_prefetch_factor = None,
|
||||
past_index = -1,
|
||||
run_name = None,
|
||||
disable_tqdm = None,
|
||||
remove_unused_columns = True,
|
||||
label_names = None,
|
||||
load_best_model_at_end = False,
|
||||
metric_for_best_model = None,
|
||||
greater_is_better = None,
|
||||
ignore_data_skip = False,
|
||||
fsdp = '',
|
||||
fsdp_min_num_params = 0,
|
||||
fsdp_config = None,
|
||||
fsdp_transformer_layer_cls_to_wrap = None,
|
||||
accelerator_config = None,
|
||||
deepspeed = None,
|
||||
label_smoothing_factor = 0.0,
|
||||
optim = 'adamw_8bit',
|
||||
optim_args = None,
|
||||
adafactor = False,
|
||||
group_by_length = False,
|
||||
length_column_name = 'length',
|
||||
report_to = None,
|
||||
ddp_find_unused_parameters = None,
|
||||
ddp_bucket_cap_mb = None,
|
||||
ddp_broadcast_buffers = None,
|
||||
dataloader_pin_memory = True,
|
||||
dataloader_persistent_workers = False,
|
||||
skip_memory_metrics = True,
|
||||
use_legacy_prediction_loop = False,
|
||||
push_to_hub = False,
|
||||
resume_from_checkpoint = None,
|
||||
hub_model_id = None,
|
||||
hub_strategy = 'every_save',
|
||||
hub_token = None,
|
||||
hub_private_repo = None,
|
||||
hub_always_push = False,
|
||||
gradient_checkpointing = False,
|
||||
gradient_checkpointing_kwargs = None,
|
||||
include_inputs_for_metrics = False,
|
||||
eval_do_concat_batches = True,
|
||||
fp16_backend = 'auto',
|
||||
push_to_hub_model_id = None,
|
||||
push_to_hub_organization = None,
|
||||
push_to_hub_token = None,
|
||||
mp_parameters = '',
|
||||
auto_find_batch_size = False,
|
||||
full_determinism = False,
|
||||
torchdynamo = None,
|
||||
ray_scope = 'last',
|
||||
ddp_timeout = 1800,
|
||||
torch_compile = False,
|
||||
torch_compile_backend = None,
|
||||
torch_compile_mode = None,
|
||||
include_tokens_per_second = False,
|
||||
include_num_input_tokens_seen = False,
|
||||
neftune_noise_alpha = None,
|
||||
optim_target_modules = None,
|
||||
batch_eval_metrics = False,
|
||||
eval_on_start = False,
|
||||
use_liger_kernel = False,
|
||||
eval_use_gather_object = False,
|
||||
average_tokens_across_devices = False,
|
||||
model_init_kwargs = None,
|
||||
max_length = None,
|
||||
truncation_mode = 'keep_end',
|
||||
optimize_device_cache = False,
|
||||
vllm_sampling_params = None,
|
||||
unsloth_num_chunks = -1,
|
||||
**kwargs,
|
||||
):
|
||||
if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small (less than 1e-7)! Consider increasing it, otherwise gradient updates will be close to 0!')
if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too large (> 1)! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
|
||||
if output_dir is None and save_strategy == 'steps' and save_steps == 500:
|
||||
output_dir = 'unsloth_training_checkpoints'
|
||||
save_strategy = 'no'
|
||||
|
||||
super().__init__(
|
||||
output_dir = output_dir,
|
||||
overwrite_output_dir = overwrite_output_dir,
|
||||
do_train = do_train,
|
||||
do_eval = do_eval,
|
||||
do_predict = do_predict,
|
||||
eval_strategy = eval_strategy,
|
||||
prediction_loss_only = prediction_loss_only,
|
||||
per_device_train_batch_size = per_device_train_batch_size,
|
||||
per_device_eval_batch_size = per_device_eval_batch_size,
|
||||
per_gpu_train_batch_size = per_gpu_train_batch_size,
|
||||
per_gpu_eval_batch_size = per_gpu_eval_batch_size,
|
||||
gradient_accumulation_steps = gradient_accumulation_steps,
|
||||
eval_accumulation_steps = eval_accumulation_steps,
|
||||
eval_delay = eval_delay,
|
||||
torch_empty_cache_steps = torch_empty_cache_steps,
|
||||
learning_rate = learning_rate,
|
||||
weight_decay = weight_decay,
|
||||
adam_beta1 = adam_beta1,
|
||||
adam_beta2 = adam_beta2,
|
||||
adam_epsilon = adam_epsilon,
|
||||
max_grad_norm = max_grad_norm,
|
||||
num_train_epochs = num_train_epochs,
|
||||
max_steps = max_steps,
|
||||
lr_scheduler_type = lr_scheduler_type,
|
||||
warmup_ratio = warmup_ratio,
|
||||
warmup_steps = warmup_steps,
|
||||
log_level = log_level,
|
||||
log_level_replica = log_level_replica,
|
||||
log_on_each_node = log_on_each_node,
|
||||
logging_dir = logging_dir,
|
||||
logging_strategy = logging_strategy,
|
||||
logging_first_step = logging_first_step,
|
||||
logging_steps = logging_steps,
|
||||
logging_nan_inf_filter = logging_nan_inf_filter,
|
||||
save_strategy = save_strategy,
|
||||
save_steps = save_steps,
|
||||
save_total_limit = save_total_limit,
|
||||
save_safetensors = save_safetensors,
|
||||
save_on_each_node = save_on_each_node,
|
||||
save_only_model = save_only_model,
|
||||
restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
|
||||
no_cuda = no_cuda,
|
||||
use_cpu = use_cpu,
|
||||
use_mps_device = use_mps_device,
|
||||
seed = seed,
|
||||
data_seed = data_seed,
|
||||
jit_mode_eval = jit_mode_eval,
|
||||
use_ipex = use_ipex,
|
||||
bf16 = bf16,
|
||||
fp16 = fp16,
|
||||
fp16_opt_level = fp16_opt_level,
|
||||
half_precision_backend = half_precision_backend,
|
||||
bf16_full_eval = bf16_full_eval,
|
||||
fp16_full_eval = fp16_full_eval,
|
||||
tf32 = tf32,
|
||||
local_rank = local_rank,
|
||||
ddp_backend = ddp_backend,
|
||||
tpu_num_cores = tpu_num_cores,
|
||||
tpu_metrics_debug = tpu_metrics_debug,
|
||||
debug = debug,
|
||||
dataloader_drop_last = dataloader_drop_last,
|
||||
eval_steps = eval_steps,
|
||||
dataloader_num_workers = dataloader_num_workers,
|
||||
dataloader_prefetch_factor = dataloader_prefetch_factor,
|
||||
past_index = past_index,
|
||||
run_name = run_name,
|
||||
disable_tqdm = disable_tqdm,
|
||||
remove_unused_columns = remove_unused_columns,
|
||||
label_names = label_names,
|
||||
load_best_model_at_end = load_best_model_at_end,
|
||||
metric_for_best_model = metric_for_best_model,
|
||||
greater_is_better = greater_is_better,
|
||||
ignore_data_skip = ignore_data_skip,
|
||||
fsdp = fsdp,
|
||||
fsdp_min_num_params = fsdp_min_num_params,
|
||||
fsdp_config = fsdp_config,
|
||||
fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
|
||||
accelerator_config = accelerator_config,
|
||||
deepspeed = deepspeed,
|
||||
label_smoothing_factor = label_smoothing_factor,
|
||||
optim = optim,
|
||||
optim_args = optim_args,
|
||||
adafactor = adafactor,
|
||||
group_by_length = group_by_length,
|
||||
length_column_name = length_column_name,
|
||||
report_to = report_to,
|
||||
ddp_find_unused_parameters = ddp_find_unused_parameters,
|
||||
ddp_bucket_cap_mb = ddp_bucket_cap_mb,
|
||||
ddp_broadcast_buffers = ddp_broadcast_buffers,
|
||||
dataloader_pin_memory = dataloader_pin_memory,
|
||||
dataloader_persistent_workers = dataloader_persistent_workers,
|
||||
skip_memory_metrics = skip_memory_metrics,
|
||||
use_legacy_prediction_loop = use_legacy_prediction_loop,
|
||||
push_to_hub = push_to_hub,
|
||||
resume_from_checkpoint = resume_from_checkpoint,
|
||||
hub_model_id = hub_model_id,
|
||||
hub_strategy = hub_strategy,
|
||||
hub_token = hub_token,
|
||||
hub_private_repo = hub_private_repo,
|
||||
hub_always_push = hub_always_push,
|
||||
gradient_checkpointing = gradient_checkpointing,
|
||||
gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
|
||||
include_inputs_for_metrics = include_inputs_for_metrics,
|
||||
eval_do_concat_batches = eval_do_concat_batches,
|
||||
fp16_backend = fp16_backend,
|
||||
push_to_hub_model_id = push_to_hub_model_id,
|
||||
push_to_hub_organization = push_to_hub_organization,
|
||||
push_to_hub_token = push_to_hub_token,
|
||||
mp_parameters = mp_parameters,
|
||||
auto_find_batch_size = auto_find_batch_size,
|
||||
full_determinism = full_determinism,
|
||||
torchdynamo = torchdynamo,
|
||||
ray_scope = ray_scope,
|
||||
ddp_timeout = ddp_timeout,
|
||||
torch_compile = torch_compile,
|
||||
torch_compile_backend = torch_compile_backend,
|
||||
torch_compile_mode = torch_compile_mode,
|
||||
include_tokens_per_second = include_tokens_per_second,
|
||||
include_num_input_tokens_seen = include_num_input_tokens_seen,
|
||||
neftune_noise_alpha = neftune_noise_alpha,
|
||||
optim_target_modules = optim_target_modules,
|
||||
batch_eval_metrics = batch_eval_metrics,
|
||||
eval_on_start = eval_on_start,
|
||||
use_liger_kernel = use_liger_kernel,
|
||||
eval_use_gather_object = eval_use_gather_object,
|
||||
average_tokens_across_devices = average_tokens_across_devices,
|
||||
model_init_kwargs = model_init_kwargs,
|
||||
max_length = max_length,
|
||||
truncation_mode = truncation_mode,
|
||||
optimize_device_cache = optimize_device_cache,**kwargs)
|
||||
self.vllm_sampling_params = vllm_sampling_params
|
||||
self.unsloth_num_chunks = unsloth_num_chunks
|
||||
pass
|
||||
|
||||
class _UnslothIterativeSFTTrainer(Trainer):
|
||||
""""""
|
||||
|
||||
_tag_names = ["trl", "iterative-sft"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Union[str, PreTrainedModel],
|
||||
args: Optional[Union[IterativeSFTConfig, TrainingArguments]] = None,
|
||||
data_collator: Optional[DataCollator] = None,
|
||||
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
|
||||
processing_class: Optional[
|
||||
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
|
||||
] = None,
|
||||
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (
|
||||
None,
|
||||
None,
|
||||
),
|
||||
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
||||
compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
|
||||
# Deprecated parameters
|
||||
max_length: Optional[int] = None,
|
||||
truncation_mode: Optional[str] = None,
|
||||
optimize_device_cache: Optional[bool] = None,
|
||||
):
|
||||
# Handle deprecated parameters
|
||||
deprecated_params = {}
|
||||
if max_length is not None:
|
||||
deprecated_params["max_length"] = max_length
|
||||
warnings.warn(
|
||||
"The `max_length` parameter is deprecated and will be removed in version 0.20. "
|
||||
"Pass it through the `args` parameter using `IterativeSFTConfig(max_length=...)` instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
if truncation_mode is not None:
|
||||
deprecated_params["truncation_mode"] = truncation_mode
|
||||
warnings.warn(
|
||||
"The `truncation_mode` parameter is deprecated and will be removed in version 0.20. "
|
||||
"Pass it through the `args` parameter using `IterativeSFTConfig(truncation_mode=...)` instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
if optimize_device_cache is not None:
|
||||
deprecated_params["optimize_device_cache"] = optimize_device_cache
|
||||
warnings.warn(
|
||||
"The `optimize_device_cache` parameter is deprecated and will be removed in version 0.20 "
|
||||
"Pass it through the `args` parameter using `IterativeSFTConfig(optimize_device_cache=...)` instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
|
||||
# Args
|
||||
model_id = model if isinstance(model, str) else model.config._name_or_path
|
||||
if args is None:
|
||||
model_name = model_id.split("/")[-1]
|
||||
args = IterativeSFTConfig(f"{model_name}-IterativeSFT")
|
||||
elif isinstance(args, TrainingArguments) and not isinstance(args, IterativeSFTConfig):
|
||||
dict_args = args.to_dict()
|
||||
dict_args["hub_token"] = args.hub_token # to_dict hides the hub_token
|
||||
dict_args.pop("push_to_hub_token")
|
||||
args = IterativeSFTConfig(**dict_args)
|
||||
|
||||
# Update args with deprecated parameters if provided
|
||||
if deprecated_params:
|
||||
for key, value in deprecated_params.items():
|
||||
setattr(args, key, value)
|
||||
|
||||
# Handle the tokenizer
|
||||
if processing_class is None:
|
||||
processing_class = AutoTokenizer.from_pretrained(model_id)
|
||||
|
||||
# Model
|
||||
if args.model_init_kwargs is not None and not isinstance(model, str):
|
||||
warnings.warn(
|
||||
"You passed model_init_kwargs to the `IterativeSFTConfig`, but your model is already instantiated. "
|
||||
"The `model_init_kwargs` will be ignored."
|
||||
)
|
||||
if isinstance(model, str):
|
||||
model = self._create_model_from_path(model, args)
|
||||
|
||||
# PEFT configuration and model wrapping
|
||||
if is_peft_available() and isinstance(model, PeftModel):
|
||||
self.is_peft_model = True
|
||||
else:
|
||||
self.is_peft_model = False
|
||||
|
||||
self.processing_class = processing_class
|
||||
self.is_encoder_decoder = getattr(model.config, "is_encoder_decoder", False)
|
||||
|
||||
if data_collator is None:
|
||||
if self.is_encoder_decoder:
|
||||
self.data_collator = DataCollatorForSeq2Seq(
|
||||
processing_class, label_pad_token_id=-100, pad_to_multiple_of=8
|
||||
)
|
||||
else:
|
||||
self.data_collator = DataCollatorForLanguageModeling(self.processing_class, mlm=False)
|
||||
else:
|
||||
self.data_collator = data_collator
|
||||
|
||||
self.max_length = args.max_length
|
||||
self.truncation_mode = args.truncation_mode
|
||||
self.optimize_device_cache = args.optimize_device_cache
|
||||
|
||||
super().__init__(
|
||||
model=model,
|
||||
args=args,
|
||||
data_collator=self.data_collator,
|
||||
eval_dataset=eval_dataset,
|
||||
processing_class=processing_class,
|
||||
compute_metrics=compute_metrics,
|
||||
optimizers=optimizers,
|
||||
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
|
||||
)
|
||||
|
||||
# Add tags for models that have been loaded with the correct transformers version
|
||||
if hasattr(self.model, "add_model_tags"):
|
||||
self.model.add_model_tags(self._tag_names)
|
||||
|
||||
self.create_optimizer_and_scheduler(self.args.max_steps)
|
||||
|
||||
# prepare model, optimizer and lr_scheduler
|
||||
self.model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
|
||||
self.model, self.optimizer, self.lr_scheduler
|
||||
)
|
||||
|
||||
self.processing_class.truncation_side = "left" if self.truncation_mode == "keep_end" else "right"
|
||||
|
||||
if not hasattr(self, "accelerator"):
|
||||
raise AttributeError(
|
||||
"Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
|
||||
)
|
||||
|
||||
PPODecorators.optimize_device_cache = self.optimize_device_cache
|
||||
|
||||
def _create_model_from_path(self, model_path: str, args: IterativeSFTConfig) -> PreTrainedModel:
|
||||
"""Creates a model from a path or model identifier."""
|
||||
model_init_kwargs = args.model_init_kwargs or {}
|
||||
return AutoModelForCausalLM.from_pretrained(model_path, **model_init_kwargs)
|
||||
|
||||
def prepare_model_inputs(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, labels: torch.Tensor):
|
||||
if attention_mask is None:
|
||||
attention_mask = [torch.ones_like(ids) for ids in input_ids]
|
||||
|
||||
if self.is_encoder_decoder:
|
||||
input_data = self.data_collator(
|
||||
[
|
||||
{"input_ids": ids, "attention_mask": att, "labels": lab}
|
||||
for ids, att, lab in zip(input_ids, attention_mask, labels)
|
||||
]
|
||||
).to(self.model.device)
|
||||
|
||||
input_data.pop("decoder_input_ids", None) # This is directly computed inside the model
|
||||
|
||||
input_data["labels"][input_data["labels"] == self.processing_class.pad_token_id] = -100
|
||||
|
||||
else:
|
||||
input_data = self.data_collator(
|
||||
[{"input_ids": ids, "attention_mask": att} for ids, att in zip(input_ids, attention_mask)]
|
||||
).to(self.model.device)
|
||||
|
||||
# truncate in case the user has provided input_ids, attention_mask and labels
|
||||
if self.max_length is not None:
|
||||
if self.truncation_mode == "keep_start":
|
||||
input_data = {k: v[: self.max_length] for k, v in input_data.items()}
|
||||
elif self.truncation_mode == "keep_end":
|
||||
input_data = {k: v[-self.max_length :] for k, v in input_data.items()}
|
||||
else:
|
||||
raise ValueError(f"Unknown truncation mode: {self.truncation_mode}")
|
||||
|
||||
return input_data
|
||||
|
||||
@staticmethod
|
||||
def _step_safety_checker(
|
||||
input_ids: list[torch.LongTensor],
|
||||
attention_mask: list[torch.LongTensor],
|
||||
labels: list[torch.LongTensor],
|
||||
texts: list[str],
|
||||
texts_labels: list[str],
|
||||
):
|
||||
"""
|
||||
Check if the input data is valid for training.
|
||||
|
||||
Args:
|
||||
input_ids (list[`torch.LongTensor`]):
|
||||
List of tensors containing the input_ids
|
||||
attention_mask (list[`torch.LongTensor`]):
|
||||
List of tensors containing the attention_mask
|
||||
labels (list[`torch.FloatTensor`]):
|
||||
List of tensors containing the labels
|
||||
texts (list[`str`]):
|
||||
List of strings containing the text input.
|
||||
texts_labels (list[`str`]):
|
||||
List of strings containing the text labels.
|
||||
|
||||
Returns:
|
||||
`tuple`: The input data.
|
||||
"""
|
||||
if texts is None:
|
||||
if attention_mask is None:
|
||||
for name, tensor_list in zip(["input_ids", "labels"], [input_ids, labels]):
|
||||
if not isinstance(tensor_list, list):
|
||||
raise ValueError(f"{name} must be a list of tensors - got {type(tensor_list)}")
|
||||
if not isinstance(tensor_list[0], torch.Tensor):
|
||||
raise ValueError(f"Elements in {name} must be tensors - got {type(tensor_list[0])}")
|
||||
else:
|
||||
for name, tensor_list in zip(
|
||||
["input_ids", "attention_mask", "labels"], [input_ids, attention_mask, labels]
|
||||
):
|
||||
if not isinstance(tensor_list, list):
|
||||
raise ValueError(f"{name} must be a list of tensors - got {type(tensor_list)}")
|
||||
if not isinstance(tensor_list[0], torch.Tensor):
|
||||
raise ValueError(f"Elements in {name} must be tensors - got {type(tensor_list[0])}")
|
||||
else:
|
||||
if not isinstance(texts, list):
|
||||
raise ValueError(f"'text' must be a list of strings - got {type(texts)}")
|
||||
if not isinstance(texts[0], str):
|
||||
raise ValueError(f"Elements in 'text' must be strings - got {type(texts[0])}")
|
||||
if texts_labels is not None:
|
||||
if not isinstance(texts_labels, list):
|
||||
raise ValueError(f"'text_labels' must be a list of strings - got {type(texts_labels)}")
|
||||
if not isinstance(texts_labels[0], str):
|
||||
raise ValueError(f"Elements in 'text_labels' must be strings - got {type(texts_labels[0])}")
|
||||
|
||||
return input_ids, attention_mask, labels, texts, texts_labels
|
||||
|
||||
@PPODecorators.empty_device_cache()
|
||||
def step(
|
||||
self,
|
||||
input_ids: Optional[list[torch.LongTensor]] = None,
|
||||
attention_mask: Optional[list[torch.LongTensor]] = None,
|
||||
labels: Optional[list[torch.LongTensor]] = None,
|
||||
texts: Optional[list[str]] = None,
|
||||
texts_labels: Optional[list[str]] = None,
|
||||
):
|
||||
"""
|
||||
Run an optimisation step given a list of input_ids, attention_mask, and labels or a list of texts and texts_labels.
|
||||
Args:
|
||||
input_ids (list[`torch.LongTensor`]):
|
||||
List of tensors containing the input_ids (if not provided, text will be used)
|
||||
attention_mask (list[`torch.LongTensor`], *optional*):
|
||||
List of tensors containing the attention_mask
|
||||
labels (list[`torch.FloatTensor`], *optional*):
|
||||
List of tensors containing the labels (if set to None, will default to input_ids)
|
||||
texts (list[`str`], *optional*):
|
||||
List of strings containing the text input (if not provided, input_ids will directly be used)
|
||||
texts_labels (list[`str`], *optional*):
|
||||
List of strings containing the text labels (if set to None, will default to text)
|
||||
|
||||
Returns:
|
||||
`dict[str, Any]`: A summary of the training statistics
|
||||
"""
|
||||
self.model.train()
|
||||
|
||||
if self.state.global_step == 0:
|
||||
self.tr_loss = torch.tensor(0.0).to(self.args.device)
|
||||
self._globalstep_last_logged = self.state.global_step
|
||||
|
||||
if input_ids is None and texts is None:
|
||||
raise ValueError("Step should include `input_ids` or `texts` as keyword arguments.")
|
||||
elif input_ids is not None and texts is not None:
|
||||
warnings.warn(
|
||||
"Both `input_ids` and `texts` argument are provided. `input_ids` will be ignored. "
|
||||
"Please provide only one of the two.",
|
||||
UserWarning,
|
||||
)
|
||||
|
||||
if labels is None and texts_labels is None and self.is_encoder_decoder:
|
||||
raise ValueError(
|
||||
"No 'labels' or 'text_labels' are provided. When using an encoder-decoder architecture, 'labels' or 'text_labels' must be passed."
|
||||
)
|
||||
|
||||
input_ids, attention_mask, labels, texts, texts_labels = self._step_safety_checker(
|
||||
input_ids, attention_mask, labels, texts, texts_labels
|
||||
)
|
||||
|
||||
if texts is not None:
|
||||
model_inputs = self.processing_class(
|
||||
texts, max_length=self.max_length, truncation=True, padding=True, return_tensors="pt"
|
||||
)
|
||||
|
||||
input_ids, attention_mask = model_inputs["input_ids"], model_inputs["attention_mask"]
|
||||
|
||||
if texts_labels is not None:
|
||||
labels = self.processing_class(
|
||||
texts_labels, max_length=self.max_length, truncation=True, padding=True, return_tensors="pt"
|
||||
)["input_ids"]
|
||||
|
||||
if labels is None:
|
||||
labels = input_ids
|
||||
|
||||
model_inputs = self.prepare_model_inputs(input_ids, attention_mask, labels)
|
||||
|
||||
model_inputs_names = list(model_inputs.keys())
|
||||
|
||||
batch_dict = {}
|
||||
batch_dict.update(model_inputs)
|
||||
|
||||
def collator(data):
|
||||
return_dict = dict()
|
||||
for key in data[0]:
|
||||
if key in ["input_ids", "attention_mask", "labels"]:
|
||||
return_dict[key] = torch.stack([d[key] for d in data]).to(self.model.device)
|
||||
return return_dict
|
||||
|
||||
batch_data = Dataset.from_dict(batch_dict)
|
||||
batch_data.set_format("torch")
|
||||
|
||||
step_dataloader = DataLoader(
|
||||
batch_data,
|
||||
batch_size=self.args.per_device_train_batch_size,
|
||||
shuffle=True,
|
||||
collate_fn=collator,
|
||||
)
|
||||
|
||||
for _, batch in enumerate(step_dataloader):
|
||||
with self.accelerator.accumulate(self.model):
|
||||
model_inputs = {k: batch[k] for k in model_inputs_names}
|
||||
loss = self.compute_loss(self.model, model_inputs)
|
||||
|
||||
if self.args.n_gpu > 1:
|
||||
loss = loss.mean()
|
||||
|
||||
tr_loss_step = loss.detach()
|
||||
|
||||
self.accelerator.backward(loss)
|
||||
|
||||
if self.accelerator.sync_gradients and self.args.max_grad_norm is not None:
|
||||
self.accelerator.clip_grad_norm_(
|
||||
self.model.parameters(),
|
||||
self.args.max_grad_norm,
|
||||
)
|
||||
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
if self.lr_scheduler is not None:
|
||||
self.lr_scheduler.step()
|
||||
|
||||
self.state.global_step += 1
|
||||
|
||||
# update stats etc
|
||||
self.tr_loss += tr_loss_step
|
||||
|
||||
self._maybe_log_save_evaluate()
|
||||
|
||||
def _maybe_log_save_evaluate(self):
|
||||
# check if eval is required
|
||||
if self.args.eval_steps is not None:
|
||||
if self.state.global_step % self.args.eval_steps == 0 and self.state.global_step != 0:
|
||||
self.evaluate(self.eval_dataset)
|
||||
|
||||
# check if logging is required
|
||||
if self.args.logging_steps is not None:
|
||||
if self.state.global_step % self.args.logging_steps == 0 and self.state.global_step != 0:
|
||||
logs: dict[str, float] = {}
|
||||
|
||||
tr_loss_scalar = self._nested_gather(self.tr_loss).mean().item()
|
||||
|
||||
# reset tr_loss to zero
|
||||
self.tr_loss -= self.tr_loss
|
||||
|
||||
logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
|
||||
logs["learning_rate"] = self._get_learning_rate()
|
||||
|
||||
self._globalstep_last_logged = self.state.global_step
|
||||
|
||||
self.log(logs)
|
||||
|
||||
def create_model_card(
|
||||
self,
|
||||
model_name: Optional[str] = None,
|
||||
dataset_name: Optional[str] = None,
|
||||
tags: Union[str, list[str], None] = None,
|
||||
):
|
||||
"""
|
||||
Creates a draft of a model card using the information available to the `Trainer`.
|
||||
|
||||
Args:
|
||||
model_name (`str` or `None`, *optional*, defaults to `None`):
|
||||
Name of the model.
|
||||
dataset_name (`str` or `None`, *optional*, defaults to `None`):
|
||||
Name of the dataset used for training.
|
||||
tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
|
||||
Tags to be associated with the model card.
|
||||
"""
|
||||
if not self.is_world_process_zero():
|
||||
return
|
||||
|
||||
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
|
||||
base_model = self.model.config._name_or_path
|
||||
else:
|
||||
base_model = None
|
||||
|
||||
tags = tags or []
|
||||
if isinstance(tags, str):
|
||||
tags = [tags]
|
||||
|
||||
if hasattr(self.model.config, "unsloth_version"):
|
||||
tags.append("unsloth")
|
||||
|
||||
model_card = generate_model_card(
|
||||
base_model=base_model,
|
||||
model_name=model_name,
|
||||
hub_model_id=self.hub_model_id,
|
||||
dataset_name=dataset_name,
|
||||
tags=tags,
|
||||
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
|
||||
comet_url=get_comet_experiment_url(),
|
||||
trainer_name="Iterative SFT",
|
||||
)
|
||||
|
||||
model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
||||
class UnslothIterativeSFTTrainer(_UnslothIterativeSFTTrainer):
|
||||
"""
|
||||
|
||||
The IterativeSFTTrainer can be used to finetune models with methods that require some steps between optimization.
|
||||
|
||||
Args:
|
||||
model (`Union[str, PreTrainedModel]`):
|
||||
Model to be trained. Can be either:
|
||||
|
||||
- A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or
|
||||
a path to a *directory* containing model weights saved using
|
||||
[`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is
|
||||
loaded using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keyword arguments
|
||||
in `args.model_init_kwargs`.
|
||||
- A [`~transformers.PreTrainedModel`] object. Only causal language models are supported.
|
||||
args ([`IterativeSFTConfig`], *optional*, defaults to `None`):
|
||||
Configuration for this trainer. If `None`, a default configuration is used.
|
||||
data_collator (`DataCollator`, *optional*):
|
||||
Function to use to form a batch from a list of elements of the processed `train_dataset` or `eval_dataset`.
|
||||
Will default to [`~transformers.default_data_collator`] if no `processing_class` is provided, or to an instance
|
||||
of [`~transformers.DataCollatorWithPadding`] if the `processing_class` is a feature extractor or
|
||||
tokenizer.
|
||||
eval_dataset (`datasets.Dataset`):
|
||||
The dataset to use for evaluation.
|
||||
processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*, defaults to `None`):
|
||||
Processing class used to process the data. If `None`, the processing class is loaded from the model's name
|
||||
with [`~transformers.AutoTokenizer.from_pretrained`].
|
||||
optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
|
||||
The optimizer and scheduler to use for training.
|
||||
preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
|
||||
The function to use to preprocess the logits before computing the metrics.
|
||||
compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
|
||||
The function to use to compute the metrics. Must take an `EvalPrediction` and return a dictionary mapping metric names to values.
|
||||
max_length (`int`, *optional*, deprecated):
|
||||
Maximum length of the tokenized sequence. Use `args.max_length` instead.
|
||||
truncation_mode (`str`, *optional*, deprecated):
|
||||
The truncation mode to use. Use `args.truncation_mode` instead.
|
||||
optimize_device_cache (`bool`, *optional*, deprecated):
|
||||
Whether to optimize CUDA cache. Use `args.optimize_device_cache` instead.
|
||||
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
args = None,
|
||||
data_collator = None,
|
||||
eval_dataset = None,
|
||||
processing_class = None,
|
||||
preprocess_logits_for_metrics = None,
|
||||
compute_metrics = None,
|
||||
max_length = None,
|
||||
truncation_mode = None,
|
||||
optimize_device_cache = None,
|
||||
**kwargs
|
||||
):
|
||||
if args is None: args = UnslothIterativeSFTConfig()
|
||||
use_bf16 = getattr(args, 'bf16', False)
|
||||
use_fp16 = getattr(args, 'fp16', False)
|
||||
force_float32 = False
|
||||
if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
|
||||
print('Unsloth: Switching to float32 training since model cannot work with float16')
|
||||
force_float32 = True
|
||||
mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
|
||||
dtype = getattr(model.config, 'torch_dtype', None)
|
||||
if dtype is None: dtype = model.get_input_embeddings().dtype
|
||||
from unsloth_zoo.utils import _get_dtype
|
||||
dtype = _get_dtype(dtype)
|
||||
float16 = dtype == torch.float16
|
||||
if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
|
||||
if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
|
||||
if force_float32:
|
||||
args.fp16 = False
|
||||
args.bf16 = False
|
||||
os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
|
||||
elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
|
||||
args.fp16 = float16
|
||||
args.bf16 = not float16
|
||||
os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
|
||||
if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
|
||||
args.eval_strategy = 'steps'
|
||||
if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
|
||||
ga_steps = getattr(args, 'gradient_accumulation_steps', None)
|
||||
if ga_steps is not None and ga_steps > 1:
|
||||
from transformers import __version__ as transformers_version
|
||||
if Version(transformers_version) <= Version('4.45.2'):
|
||||
print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
|
||||
'`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
|
||||
if getattr(args, 'eval_strategy', 'no') != 'no':
|
||||
eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
|
||||
if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
|
||||
if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
|
||||
fp16_full_eval = getattr(args, 'fp16_full_eval', False)
|
||||
bf16_full_eval = getattr(args, 'bf16_full_eval', False)
|
||||
if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
|
||||
if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
|
||||
if force_float32:
|
||||
args.bf16_full_eval = False
|
||||
args.fp16_full_eval = False
|
||||
elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
|
||||
args.bf16_full_eval = True
|
||||
args.fp16_full_eval = False
|
||||
elif not bf16_full_eval and not fp16_full_eval:
|
||||
args.bf16_full_eval = args.bf16
|
||||
args.fp16_full_eval = args.fp16
|
||||
_output_logits = False
|
||||
if locals().get('compute_metrics', None) is not None: _output_logits = True
|
||||
if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
|
||||
if _output_logits:
|
||||
os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
|
||||
if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
|
||||
pass
|
||||
else:
|
||||
model_max_seq_length = getattr(model, 'max_seq_length', None)
|
||||
args_max_seq_length = getattr(args, 'max_seq_length', None)
|
||||
if args_max_seq_length is None and model_max_seq_length is not None:
|
||||
max_seq_length = model.max_seq_length
|
||||
if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
|
||||
if model is not None and hasattr(model, 'for_training'):
|
||||
model.for_training()
|
||||
if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
|
||||
if 'processing_class' in locals():
|
||||
if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
|
||||
if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
|
||||
other_metrics = []
|
||||
|
||||
from unsloth_zoo.logging_utils import PatchRLStatistics
|
||||
PatchRLStatistics('iterative_sft_trainer', other_metrics)
|
||||
|
||||
super().__init__(
|
||||
model = model,
|
||||
args = args,
|
||||
data_collator = data_collator,
|
||||
eval_dataset = eval_dataset,
|
||||
processing_class = processing_class,
|
||||
preprocess_logits_for_metrics = preprocess_logits_for_metrics,
|
||||
compute_metrics = compute_metrics,
|
||||
max_length = max_length,
|
||||
truncation_mode = truncation_mode,
|
||||
optimize_device_cache = optimize_device_cache,**kwargs)
|
||||
if hasattr(self, 'neftune_hook_handle'):
|
||||
self.neftune_hook_handle.remove()
|
||||
if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
|
||||
if getattr(args, 'neftune_noise_alpha', None) is not None:
|
||||
model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
|
||||
pass
|
||||
|
||||
pass
|
||||
1989
unsloth_compiled_cache/UnslothKTOTrainer.py
Normal file
File diff suppressed because it is too large
971
unsloth_compiled_cache/UnslothNashMDTrainer.py
Normal file
@@ -0,0 +1,971 @@
|
||||
"""
|
||||
2025.6.1
|
||||
2025.6.2
|
||||
4.52.4
|
||||
0.18.2
|
||||
__UNSLOTH_VERSIONING__
|
||||
"""
|
||||
from torch import Tensor
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
from trl.trainer.nash_md_trainer import (Any, BaseImageProcessor, BasePairwiseJudge, Callable, Dataset, EvalPrediction, F, FeatureExtractionMixin, GeometricMixtureWrapper, IterableDataset, NashMDConfig, NashMDTrainer, OnlineDPOTrainer, OptimizerNames, Optional, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, SIMPLE_CHAT_TEMPLATE, TrainerCallback, Union, empty_cache, generate_model_card, get_comet_experiment_url, get_reward, is_conversational, is_peft_available, is_wandb_available, jinja2, maybe_apply_chat_template, nn, os, textwrap, torch, truncate_right, unwrap_model_for_generation)
|
||||
|
||||
|
||||
import os
|
||||
from typing import *
|
||||
from dataclasses import dataclass, field
|
||||
from packaging.version import Version
|
||||
import torch
|
||||
import numpy as np
|
||||
from contextlib import nullcontext
|
||||
from torch.nn import functional as F
|
||||
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
|
||||
|
||||
torch_compile_options = {
|
||||
"epilogue_fusion" : True,
|
||||
"max_autotune" : False,
|
||||
"shape_padding" : True,
|
||||
"trace.enabled" : False,
|
||||
"triton.cudagraphs" : False,
|
||||
}
|
||||
|
||||
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
||||
def selective_log_softmax(logits, index):
|
||||
logits = logits.to(torch.float32)
|
||||
selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
|
||||
# loop to reduce peak mem consumption
|
||||
# logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
|
||||
logsumexp_values = torch.logsumexp(logits, dim = -1)
|
||||
per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
|
||||
return per_token_logps
|
||||
@dataclass
|
||||
class UnslothNashMDConfig(NashMDConfig):
|
||||
"""
|
||||
|
||||
Configuration class for the [`NashMDTrainer`].
|
||||
|
||||
Subclass of [`OnlineDPOConfig`]; we can use all its arguments and add the following:
|
||||
|
||||
Parameters:
|
||||
mixture_coef (`float` or `list[float]`, *optional*, defaults to `0.5`):
|
||||
Logit mixture coefficient for the model and reference model. If a list of floats is provided then the
|
||||
mixture coefficient is selected for each new epoch and the last coefficient is used for the rest of the
|
||||
epochs.
|
||||
|
||||
"""
|
||||
vllm_sampling_params: Optional[Any] = field(
|
||||
default = None,
|
||||
metadata = {'help': 'vLLM SamplingParams'},
|
||||
)
|
||||
unsloth_num_chunks : Optional[int] = field(
|
||||
default = -1,
|
||||
metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
|
||||
)
|
||||
def __init__(
|
||||
self,
|
||||
output_dir = None,
|
||||
overwrite_output_dir = None,
|
||||
do_train = False,
|
||||
do_eval = False,
|
||||
do_predict = False,
|
||||
eval_strategy = 'no',
|
||||
prediction_loss_only = False,
|
||||
per_device_train_batch_size = 4,
|
||||
per_device_eval_batch_size = 4,
|
||||
per_gpu_train_batch_size = None,
|
||||
per_gpu_eval_batch_size = None,
|
||||
gradient_accumulation_steps = 2,
|
||||
eval_accumulation_steps = 2,
|
||||
eval_delay = 0,
|
||||
torch_empty_cache_steps = 250,
|
||||
learning_rate = 5e-05,
|
||||
weight_decay = 0.01,
|
||||
adam_beta1 = 0.9,
|
||||
adam_beta2 = 0.999,
|
||||
adam_epsilon = 1e-08,
|
||||
max_grad_norm = 1.0,
|
||||
num_train_epochs = 3.0,
|
||||
max_steps = -1,
|
||||
lr_scheduler_type = 'linear',
|
||||
warmup_ratio = 0.1,
|
||||
warmup_steps = 0,
|
||||
log_level = 'passive',
|
||||
log_level_replica = 'warning',
|
||||
log_on_each_node = True,
|
||||
logging_dir = None,
|
||||
logging_strategy = 'steps',
|
||||
logging_first_step = False,
|
||||
logging_steps = 1,
|
||||
logging_nan_inf_filter = False,
|
||||
save_strategy = 'steps',
|
||||
save_steps = 500,
|
||||
save_total_limit = None,
|
||||
save_safetensors = True,
|
||||
save_on_each_node = False,
|
||||
save_only_model = False,
|
||||
restore_callback_states_from_checkpoint = False,
|
||||
no_cuda = False,
|
||||
use_cpu = False,
|
||||
use_mps_device = False,
|
||||
seed = 3407,
|
||||
data_seed = 3407,
|
||||
jit_mode_eval = False,
|
||||
use_ipex = False,
|
||||
bf16 = False,
|
||||
fp16 = False,
|
||||
fp16_opt_level = 'O1',
|
||||
half_precision_backend = 'auto',
|
||||
bf16_full_eval = False,
|
||||
fp16_full_eval = False,
|
||||
tf32 = None,
|
||||
local_rank = -1,
|
||||
ddp_backend = None,
|
||||
tpu_num_cores = None,
|
||||
tpu_metrics_debug = False,
|
||||
debug = '',
|
||||
dataloader_drop_last = False,
|
||||
eval_steps = None,
|
||||
dataloader_num_workers = 0,
|
||||
dataloader_prefetch_factor = None,
|
||||
past_index = -1,
|
||||
run_name = None,
|
||||
disable_tqdm = None,
|
||||
remove_unused_columns = True,
|
||||
label_names = None,
|
||||
load_best_model_at_end = False,
|
||||
metric_for_best_model = None,
|
||||
greater_is_better = None,
|
||||
ignore_data_skip = False,
|
||||
fsdp = '',
|
||||
fsdp_min_num_params = 0,
|
||||
fsdp_config = None,
|
||||
fsdp_transformer_layer_cls_to_wrap = None,
|
||||
accelerator_config = None,
|
||||
deepspeed = None,
|
||||
label_smoothing_factor = 0.0,
|
||||
optim = 'adamw_8bit',
|
||||
optim_args = None,
|
||||
adafactor = False,
|
||||
group_by_length = False,
|
||||
length_column_name = 'length',
|
||||
report_to = None,
|
||||
ddp_find_unused_parameters = None,
|
||||
ddp_bucket_cap_mb = None,
|
||||
ddp_broadcast_buffers = None,
|
||||
dataloader_pin_memory = True,
|
||||
dataloader_persistent_workers = False,
|
||||
skip_memory_metrics = True,
|
||||
use_legacy_prediction_loop = False,
|
||||
push_to_hub = False,
|
||||
resume_from_checkpoint = None,
|
||||
hub_model_id = None,
|
||||
hub_strategy = 'every_save',
|
||||
hub_token = None,
|
||||
hub_private_repo = None,
|
||||
hub_always_push = False,
|
||||
gradient_checkpointing = False,
|
||||
gradient_checkpointing_kwargs = None,
|
||||
include_inputs_for_metrics = False,
|
||||
eval_do_concat_batches = True,
|
||||
fp16_backend = 'auto',
|
||||
push_to_hub_model_id = None,
|
||||
push_to_hub_organization = None,
|
||||
push_to_hub_token = None,
|
||||
mp_parameters = '',
|
||||
auto_find_batch_size = False,
|
||||
full_determinism = False,
|
||||
torchdynamo = None,
|
||||
ray_scope = 'last',
|
||||
ddp_timeout = 1800,
|
||||
torch_compile = False,
|
||||
torch_compile_backend = None,
|
||||
torch_compile_mode = None,
|
||||
include_tokens_per_second = False,
|
||||
include_num_input_tokens_seen = False,
|
||||
neftune_noise_alpha = None,
|
||||
optim_target_modules = None,
|
||||
batch_eval_metrics = False,
|
||||
eval_on_start = False,
|
||||
use_liger_kernel = False,
|
||||
eval_use_gather_object = False,
|
||||
average_tokens_across_devices = False,
|
||||
reward_model_path = None,
|
||||
judge = None,
|
||||
max_new_tokens = 64,
|
||||
max_length = 512,
|
||||
temperature = 0.9,
|
||||
missing_eos_penalty = None,
|
||||
loss_type = 'sigmoid',
|
||||
dataset_num_proc = None,
|
||||
disable_dropout = True,
|
||||
use_vllm = False,
|
||||
gpu_memory_utilization = 0.55,
|
||||
ds3_gather_for_generation = True,
|
||||
vllm_sampling_params = None,
|
||||
unsloth_num_chunks = -1,
|
||||
**kwargs,
|
||||
):
|
||||
if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
|
||||
if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too large (> 1)! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
|
||||
if output_dir is None and save_strategy == 'steps' and save_steps == 500:
|
||||
output_dir = 'unsloth_training_checkpoints'
|
||||
save_strategy = 'no'
|
||||
if dataset_num_proc is None:
|
||||
from multiprocessing import cpu_count
|
||||
dataset_num_proc = cpu_count()
|
||||
|
||||
super().__init__(
|
||||
output_dir = output_dir,
|
||||
overwrite_output_dir = overwrite_output_dir,
|
||||
do_train = do_train,
|
||||
do_eval = do_eval,
|
||||
do_predict = do_predict,
|
||||
eval_strategy = eval_strategy,
|
||||
prediction_loss_only = prediction_loss_only,
|
||||
per_device_train_batch_size = per_device_train_batch_size,
|
||||
per_device_eval_batch_size = per_device_eval_batch_size,
|
||||
per_gpu_train_batch_size = per_gpu_train_batch_size,
|
||||
per_gpu_eval_batch_size = per_gpu_eval_batch_size,
|
||||
gradient_accumulation_steps = gradient_accumulation_steps,
|
||||
eval_accumulation_steps = eval_accumulation_steps,
|
||||
eval_delay = eval_delay,
|
||||
torch_empty_cache_steps = torch_empty_cache_steps,
|
||||
learning_rate = learning_rate,
|
||||
weight_decay = weight_decay,
|
||||
adam_beta1 = adam_beta1,
|
||||
adam_beta2 = adam_beta2,
|
||||
adam_epsilon = adam_epsilon,
|
||||
max_grad_norm = max_grad_norm,
|
||||
num_train_epochs = num_train_epochs,
|
||||
max_steps = max_steps,
|
||||
lr_scheduler_type = lr_scheduler_type,
|
||||
warmup_ratio = warmup_ratio,
|
||||
warmup_steps = warmup_steps,
|
||||
log_level = log_level,
|
||||
log_level_replica = log_level_replica,
|
||||
log_on_each_node = log_on_each_node,
|
||||
logging_dir = logging_dir,
|
||||
logging_strategy = logging_strategy,
|
||||
logging_first_step = logging_first_step,
|
||||
logging_steps = logging_steps,
|
||||
logging_nan_inf_filter = logging_nan_inf_filter,
|
||||
save_strategy = save_strategy,
|
||||
save_steps = save_steps,
|
||||
save_total_limit = save_total_limit,
|
||||
save_safetensors = save_safetensors,
|
||||
save_on_each_node = save_on_each_node,
|
||||
save_only_model = save_only_model,
|
||||
restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
|
||||
no_cuda = no_cuda,
|
||||
use_cpu = use_cpu,
|
||||
use_mps_device = use_mps_device,
|
||||
seed = seed,
|
||||
data_seed = data_seed,
|
||||
jit_mode_eval = jit_mode_eval,
|
||||
use_ipex = use_ipex,
|
||||
bf16 = bf16,
|
||||
fp16 = fp16,
|
||||
fp16_opt_level = fp16_opt_level,
|
||||
half_precision_backend = half_precision_backend,
|
||||
bf16_full_eval = bf16_full_eval,
|
||||
fp16_full_eval = fp16_full_eval,
|
||||
tf32 = tf32,
|
||||
local_rank = local_rank,
|
||||
ddp_backend = ddp_backend,
|
||||
tpu_num_cores = tpu_num_cores,
|
||||
tpu_metrics_debug = tpu_metrics_debug,
|
||||
debug = debug,
|
||||
dataloader_drop_last = dataloader_drop_last,
|
||||
eval_steps = eval_steps,
|
||||
dataloader_num_workers = dataloader_num_workers,
|
||||
dataloader_prefetch_factor = dataloader_prefetch_factor,
|
||||
past_index = past_index,
|
||||
run_name = run_name,
|
||||
disable_tqdm = disable_tqdm,
|
||||
remove_unused_columns = remove_unused_columns,
|
||||
label_names = label_names,
|
||||
load_best_model_at_end = load_best_model_at_end,
|
||||
metric_for_best_model = metric_for_best_model,
|
||||
greater_is_better = greater_is_better,
|
||||
ignore_data_skip = ignore_data_skip,
|
||||
fsdp = fsdp,
|
||||
fsdp_min_num_params = fsdp_min_num_params,
|
||||
fsdp_config = fsdp_config,
|
||||
fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
|
||||
accelerator_config = accelerator_config,
|
||||
deepspeed = deepspeed,
|
||||
label_smoothing_factor = label_smoothing_factor,
|
||||
optim = optim,
|
||||
optim_args = optim_args,
|
||||
adafactor = adafactor,
|
||||
group_by_length = group_by_length,
|
||||
length_column_name = length_column_name,
|
||||
report_to = report_to,
|
||||
ddp_find_unused_parameters = ddp_find_unused_parameters,
|
||||
ddp_bucket_cap_mb = ddp_bucket_cap_mb,
|
||||
ddp_broadcast_buffers = ddp_broadcast_buffers,
|
||||
dataloader_pin_memory = dataloader_pin_memory,
|
||||
dataloader_persistent_workers = dataloader_persistent_workers,
|
||||
skip_memory_metrics = skip_memory_metrics,
|
||||
use_legacy_prediction_loop = use_legacy_prediction_loop,
|
||||
push_to_hub = push_to_hub,
|
||||
resume_from_checkpoint = resume_from_checkpoint,
|
||||
hub_model_id = hub_model_id,
|
||||
hub_strategy = hub_strategy,
|
||||
hub_token = hub_token,
|
||||
hub_private_repo = hub_private_repo,
|
||||
hub_always_push = hub_always_push,
|
||||
gradient_checkpointing = gradient_checkpointing,
|
||||
gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
|
||||
include_inputs_for_metrics = include_inputs_for_metrics,
|
||||
eval_do_concat_batches = eval_do_concat_batches,
|
||||
fp16_backend = fp16_backend,
|
||||
push_to_hub_model_id = push_to_hub_model_id,
|
||||
push_to_hub_organization = push_to_hub_organization,
|
||||
push_to_hub_token = push_to_hub_token,
|
||||
mp_parameters = mp_parameters,
|
||||
auto_find_batch_size = auto_find_batch_size,
|
||||
full_determinism = full_determinism,
|
||||
torchdynamo = torchdynamo,
|
||||
ray_scope = ray_scope,
|
||||
ddp_timeout = ddp_timeout,
|
||||
torch_compile = torch_compile,
|
||||
torch_compile_backend = torch_compile_backend,
|
||||
torch_compile_mode = torch_compile_mode,
|
||||
include_tokens_per_second = include_tokens_per_second,
|
||||
include_num_input_tokens_seen = include_num_input_tokens_seen,
|
||||
neftune_noise_alpha = neftune_noise_alpha,
|
||||
optim_target_modules = optim_target_modules,
|
||||
batch_eval_metrics = batch_eval_metrics,
|
||||
eval_on_start = eval_on_start,
|
||||
use_liger_kernel = use_liger_kernel,
|
||||
eval_use_gather_object = eval_use_gather_object,
|
||||
average_tokens_across_devices = average_tokens_across_devices,
|
||||
reward_model_path = reward_model_path,
|
||||
judge = judge,
|
||||
max_new_tokens = max_new_tokens,
|
||||
max_length = max_length,
|
||||
temperature = temperature,
|
||||
missing_eos_penalty = missing_eos_penalty,
|
||||
loss_type = loss_type,
|
||||
dataset_num_proc = dataset_num_proc,
|
||||
disable_dropout = disable_dropout,
|
||||
use_vllm = use_vllm,
|
||||
gpu_memory_utilization = gpu_memory_utilization,
|
||||
ds3_gather_for_generation = ds3_gather_for_generation,**kwargs)
|
||||
self.vllm_sampling_params = vllm_sampling_params
|
||||
self.unsloth_num_chunks = unsloth_num_chunks
|
||||
pass
|
||||
|
||||
class _UnslothNashMDTrainer(OnlineDPOTrainer):
|
||||
r""""""
|
||||
|
||||
_tag_names = ["trl", "nash-md"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Union[PreTrainedModel, nn.Module] = None,
|
||||
ref_model: Union[PreTrainedModel, nn.Module] = None,
|
||||
reward_model: Union[PreTrainedModel, nn.Module, None] = None,
|
||||
judge: Optional[BasePairwiseJudge] = None,
|
||||
args: Optional[NashMDConfig] = None,
|
||||
data_collator: Optional[Callable] = None,
|
||||
train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
|
||||
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
|
||||
processing_class: Optional[
|
||||
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
|
||||
] = None,
|
||||
peft_config: Optional[dict] = None,
|
||||
compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
|
||||
callbacks: Optional[list[TrainerCallback]] = None,
|
||||
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
|
||||
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
model=model,
|
||||
ref_model=ref_model,
|
||||
reward_model=reward_model,
|
||||
judge=judge,
|
||||
args=args,
|
||||
data_collator=data_collator,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
processing_class=processing_class,
|
||||
reward_processing_class=processing_class, # for now, NashMDTrainer can't use any reward model
|
||||
peft_config=peft_config,
|
||||
compute_metrics=compute_metrics,
|
||||
callbacks=callbacks,
|
||||
optimizers=optimizers,
|
||||
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
|
||||
)
|
||||
|
||||
self._mixture_coef = self.args.mixture_coef
|
||||
|
||||
# Overwrite the stats dictionary to include NashMD specific statistics
|
||||
self.stats = {
|
||||
# Remove "non_score_reward", "rlhf_reward", "scores_margin"
|
||||
# Add "mixture_coef"
|
||||
"loss/kl": [],
|
||||
"objective/entropy": [],
|
||||
"loss/score": [],
|
||||
"rewards/probabilities": [],
|
||||
"rewards/accuracies": [],
|
||||
"rewards/margins": [],
|
||||
"logps/chosen": [],
|
||||
"logps/rejected": [],
|
||||
"val/model_contain_eos_token": [],
|
||||
"val/ref_contain_eos_token": [],
|
||||
"beta": [],
|
||||
"mixture_coef": [],
|
||||
}
|
||||
if self.reward_model is not None:
|
||||
self.stats["rewards/chosen"] = []
|
||||
self.stats["rewards/rejected"] = []
|
||||
|
||||
@property
|
||||
def mixture_coef(self):
|
||||
if isinstance(self._mixture_coef, list):
|
||||
epoch = self.state.epoch
|
||||
return self._mixture_coef[epoch] if epoch < len(self._mixture_coef) else self._mixture_coef[-1]
|
||||
else:
|
||||
return self._mixture_coef
|
||||
|
||||
def _generate_completions(self, model, prompts):
|
||||
# Generate completions from the policy model.
|
||||
with unwrap_model_for_generation(model, self.accelerator) as unwrapped_policy_for_gen_ctx:
|
||||
model_output = unwrapped_policy_for_gen_ctx.generate(
|
||||
input_ids=prompts["input_ids"],
|
||||
attention_mask=prompts["attention_mask"],
|
||||
generation_config=self.generation_config,
|
||||
)
|
||||
|
||||
# Get the DDP/FSDP unwrapped version of the main model.
|
||||
# This will be the policy model for GeometricMixtureWrapper (PEFT adapters active if PEFT is used).
|
||||
policy_model_for_gmw = self.accelerator.unwrap_model(model)
|
||||
|
||||
# Determine the correct reference model for GeometricMixtureWrapper.
|
||||
# This also needs to be DDP/FSDP unwrapped.
|
||||
ref_model_for_gmw: torch.nn.Module
|
||||
if self.ref_model is None:
|
||||
# No explicit ref_model is provided.
|
||||
# Use the base of the main `model` if it's a PEFT model.
|
||||
# policy_model_for_gmw is already DDP-unwrapped.
|
||||
if is_peft_available() and isinstance(policy_model_for_gmw, PeftModel):
|
||||
ref_model_for_gmw = policy_model_for_gmw.get_base_model()
|
||||
else:
|
||||
# Not a PEFT model (or PEFT not available), or already a base model.
|
||||
# Use the DDP-unwrapped policy model itself as the reference.
|
||||
ref_model_for_gmw = policy_model_for_gmw
|
||||
else:
|
||||
# An explicit ref_model is provided. Unwrap it for DDP/FSDP.
|
||||
ref_model_for_gmw = self.accelerator.unwrap_model(self.ref_model)
|
||||
|
||||
# Both models given to GeometricMixtureWrapper (policy_model_for_gmw and ref_model_for_gmw) are DDP-unwrapped.
|
||||
with torch.no_grad(): # Ensure no_grad context for mixture model generation
|
||||
mixture_model = GeometricMixtureWrapper(
|
||||
model=policy_model_for_gmw,
|
||||
ref_model=ref_model_for_gmw,
|
||||
generation_config=self.generation_config,
|
||||
mixture_coef=self.mixture_coef,
|
||||
device=self.accelerator.device,
|
||||
)
|
||||
|
||||
mixture_output = mixture_model.generate(
|
||||
input_ids=prompts["input_ids"],
|
||||
attention_mask=prompts["attention_mask"],
|
||||
generation_config=self.generation_config,
|
||||
)
|
||||
|
||||
return model_output, mixture_output
|
||||
|
||||
def _process_completions(self, model_output, mixture_output, prompts):
|
||||
context_length = prompts["input_ids"].shape[1]
|
||||
|
||||
# Process model completions
|
||||
model_completion_ids = model_output[:, context_length:]
|
||||
model_completion_ids, model_completion_mask = truncate_right(
|
||||
model_completion_ids, self.processing_class.eos_token_id, self.processing_class.pad_token_id
|
||||
)
|
||||
model_data = {
|
||||
"input_ids": torch.cat((prompts["input_ids"], model_completion_ids), dim=1),
|
||||
"attention_mask": torch.cat((prompts["attention_mask"], model_completion_mask), dim=1),
|
||||
"raw": prompts["raw"],
|
||||
}
|
||||
|
||||
# Process reference model completions
|
||||
mixture_completion_ids = mixture_output[:, context_length:]
|
||||
mixture_completion_ids, mixture_completion_mask = truncate_right(
|
||||
mixture_completion_ids, self.processing_class.eos_token_id, self.processing_class.pad_token_id
|
||||
)
|
||||
mixture_data = {
|
||||
"input_ids": torch.cat((prompts["input_ids"], mixture_completion_ids), dim=1),
|
||||
"attention_mask": torch.cat((prompts["attention_mask"], mixture_completion_mask), dim=1),
|
||||
"raw": prompts["raw"],
|
||||
}
|
||||
|
||||
return model_data, mixture_data
|
||||
|
||||
def _compute_rewards(self, model_data, mixture_data, context_length):
|
||||
with torch.no_grad():
|
||||
_, model_scores, _ = get_reward(
|
||||
self.reward_model, model_data["input_ids"], self.processing_class.pad_token_id, context_length
|
||||
)
|
||||
_, mixture_scores, _ = get_reward(
|
||||
self.reward_model, mixture_data["input_ids"], self.processing_class.pad_token_id, context_length
|
||||
)
|
||||
|
||||
# Apply EOS penalty if needed
|
||||
if self.args.missing_eos_penalty is not None:
|
||||
model_contain_eos = torch.any(model_data["input_ids"] == self.processing_class.eos_token_id, dim=-1)
|
||||
mixture_contain_eos = torch.any(mixture_data["input_ids"] == self.processing_class.eos_token_id, dim=-1)
|
||||
model_scores[~model_contain_eos] -= self.args.missing_eos_penalty
|
||||
mixture_scores[~mixture_contain_eos] -= self.args.missing_eos_penalty
|
||||
|
||||
return model_scores, mixture_scores
|
||||
|
||||
def _compute_judge(self, model_data, mixture_data, context_length):
|
||||
prompts = model_data["raw"]
|
||||
model_data_completions = self.processing_class.batch_decode(
|
||||
model_data["input_ids"][:, context_length:], skip_special_tokens=True
|
||||
)
|
||||
model_data_completions = [completion.strip() for completion in model_data_completions]
|
||||
|
||||
mixture_data_completions = self.processing_class.batch_decode(
|
||||
mixture_data["input_ids"][:, context_length:], skip_special_tokens=True
|
||||
)
|
||||
mixture_data_completions = [completion.strip() for completion in mixture_data_completions]
|
||||
if is_conversational({"prompt": prompts[0]}):
|
||||
model_data_completions = [
|
||||
[{"role": "assistant", "content": completion}] for completion in model_data_completions
|
||||
]
|
||||
environment = jinja2.Environment()
|
||||
template = environment.from_string(SIMPLE_CHAT_TEMPLATE)
|
||||
prompts = [template.render(messages=message) for message in prompts]
|
||||
model_data_completions = [template.render(messages=completion) for completion in model_data_completions]
|
||||
|
||||
mixture_data_completions = [
|
||||
[{"role": "assistant", "content": completion}] for completion in mixture_data_completions
|
||||
]
|
||||
mixture_data_completions = [
|
||||
template.render(messages=completion) for completion in mixture_data_completions
|
||||
]
|
||||
|
||||
probability = self.judge.judge(
|
||||
prompts,
|
||||
list(zip(model_data_completions, mixture_data_completions)),
|
||||
return_scores=True,
|
||||
)
|
||||
return torch.tensor(probability, device=model_data["input_ids"].device)
|
||||
|
||||
def _compute_logprobs(self, model, model_data, context_length):
|
||||
def compute_logprobs_for_data(m, data):
|
||||
output = m(data["input_ids"], attention_mask=data["attention_mask"])
|
||||
logits = output.logits[:, context_length - 1 : -1]
|
||||
token_logprobs = selective_log_softmax(logits, data["input_ids"][:, context_length:])
|
||||
return token_logprobs
|
||||
|
||||
# Compute logprobs for model completions under the model
|
||||
model_logprobs_model_data = compute_logprobs_for_data(model, model_data)
|
||||
|
||||
# Compute logprobs of model completions under the reference model
|
||||
with torch.no_grad():
|
||||
if self.ref_model is None:
|
||||
with model.disable_adapter():
|
||||
ref_logprobs_model_data = compute_logprobs_for_data(model, model_data)
|
||||
else:
|
||||
ref_logprobs_model_data = compute_logprobs_for_data(self.ref_model, model_data)
|
||||
|
||||
# Mask padding tokens
|
||||
model_padding_mask = model_data["attention_mask"][:, context_length:] == 0
|
||||
model_logprobs_model_data = model_logprobs_model_data.masked_fill(model_padding_mask, 0.0)
|
||||
ref_logprobs_model_data = ref_logprobs_model_data.masked_fill(model_padding_mask, 0.0)
|
||||
|
||||
return (model_logprobs_model_data, ref_logprobs_model_data)
|
||||
|
||||
def _compute_losses(
|
||||
self,
|
||||
model_logprobs_model_data,
|
||||
ref_logprobs_model_data,
|
||||
probability,
|
||||
):
|
||||
# reinforce score where 0.5 is a control variate
|
||||
score = (probability - 0.5) * model_logprobs_model_data.sum(1)
|
||||
|
||||
# kl divergence via reinforce
|
||||
with torch.no_grad():
|
||||
log_ratio = model_logprobs_model_data - ref_logprobs_model_data
|
||||
kl_div_log = log_ratio.sum(1)
|
||||
kl_div_loss = (log_ratio * model_logprobs_model_data).sum(1)
|
||||
|
||||
# final loss
|
||||
loss = self.beta * kl_div_loss - score
|
||||
|
||||
return loss.mean(), score, kl_div_log
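# A sketch of the objective as implemented above (not an authoritative
# derivation): per sequence,
#     score = (p - 0.5) * sum_t log pi(y_t | x, y_<t)
#     kl    = sum_t (log pi - log pi_ref)                 (logged, no gradient)
#     loss  = beta * sum_t (log pi - log pi_ref) * log pi  -  score
# i.e. a REINFORCE-style surrogate where the preference probability `p`
# (centred at 0.5 as a control variate) pushes the policy toward preferred
# completions, while the beta-weighted term regularizes it toward the
# reference model.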
|
||||
|
||||
def _log_statistics(
|
||||
self,
|
||||
model_data,
|
||||
mixture_data,
|
||||
model_logprobs_model_data,
|
||||
ref_logprobs_model_data,
|
||||
probability,
|
||||
score,
|
||||
kl_div,
|
||||
context_length,
|
||||
model_scores=None,
|
||||
mixture_scores=None,
|
||||
):
|
||||
# Helper function to gather and compute mean
|
||||
def gather_mean(tensor):
|
||||
return self.accelerator.gather_for_metrics(tensor).mean().item()
|
||||
|
||||
# Log score
|
||||
self.stats["loss/score"].append(gather_mean(score))
|
||||
# Log KL divergence
|
||||
self.stats["loss/kl"].append(gather_mean(kl_div))
|
||||
|
||||
# Log logprobs
|
||||
model_logprobs_model_data_sum = model_logprobs_model_data.sum(1)
|
||||
ref_logprobs_model_data_sum = ref_logprobs_model_data.sum(1)
|
||||
|
||||
self.stats["logps/chosen"].append(gather_mean(model_logprobs_model_data_sum))
|
||||
self.stats["logps/rejected"].append(gather_mean(ref_logprobs_model_data_sum))
|
||||
|
||||
# Log rewards
|
||||
if self.reward_model is not None:
|
||||
self.stats["rewards/chosen"].append(gather_mean(model_scores))
|
||||
self.stats["rewards/rejected"].append(gather_mean(mixture_scores))
|
||||
|
||||
# Log probabilities
|
||||
self.stats["rewards/probabilities"].append(gather_mean(probability))
|
||||
|
||||
# Calculate entropy for model data
|
||||
entropy_model_data = -model_logprobs_model_data.sum(1)
|
||||
self.stats["objective/entropy"].append(gather_mean(entropy_model_data))
|
||||
|
||||
# Calculate margins
|
||||
margin = model_logprobs_model_data_sum - ref_logprobs_model_data_sum
|
||||
self.stats["rewards/margins"].append(gather_mean(margin))
|
||||
|
||||
# Calculate accuracy
|
||||
accuracy = (margin > 0).float()
|
||||
self.stats["rewards/accuracies"].append(gather_mean(accuracy))
|
||||
|
||||
# Log EOS token statistics
|
||||
model_eos = (model_data["input_ids"][:, context_length:] == self.processing_class.eos_token_id).any(dim=1)
|
||||
mixture_eos = (mixture_data["input_ids"][:, context_length:] == self.processing_class.eos_token_id).any(dim=1)
|
||||
self.stats["val/model_contain_eos_token"].append(gather_mean(model_eos.float()))
|
||||
self.stats["val/ref_contain_eos_token"].append(gather_mean(mixture_eos.float()))
|
||||
|
||||
# Log beta and mixture coef
|
||||
self.stats["beta"].append(self.beta)
|
||||
self.stats["mixture_coef"].append(self.mixture_coef)
|
||||
|
||||
def training_step(
|
||||
self, model: nn.Module, inputs: dict[str, Union[torch.Tensor, Any]], num_items_in_batch: Optional[int] = None
|
||||
) -> torch.Tensor:
|
||||
model.train()
|
||||
|
||||
# Apply chat template and tokenize the input
|
||||
batch_size = len(next(iter(inputs.values())))
|
||||
prompts = inputs["prompt"]
|
||||
inputs = [{k: v[i] for k, v in inputs.items()} for i in range(batch_size)]
|
||||
inputs = [maybe_apply_chat_template(x, self.processing_class) for x in inputs]
|
||||
inputs = [self.tokenize_row(x, self.model.config.is_encoder_decoder, self.processing_class) for x in inputs]
|
||||
inputs = self.data_collator(inputs)
|
||||
|
||||
# need the prompt_ only
|
||||
inputs = self._prepare_inputs(inputs)
|
||||
context_length = inputs["prompt_input_ids"].shape[1]
|
||||
prompts = {
|
||||
"input_ids": inputs["prompt_input_ids"],
|
||||
"attention_mask": inputs["prompt_attention_mask"],
|
||||
"raw": prompts,
|
||||
}
|
||||
del inputs
|
||||
|
||||
# Sample completions from both the model and the reference model
|
||||
model_output, mixture_output = self._generate_completions(model, prompts)
|
||||
|
||||
# Process model completions
|
||||
model_data, mixture_data = self._process_completions(model_output, mixture_output, prompts)
|
||||
|
||||
# Compute rewards
|
||||
if self.reward_model is not None:
|
||||
model_scores, mixture_scores = self._compute_rewards(model_data, mixture_data, context_length)
|
||||
# probability of the model data vs the mixture data
|
||||
probability = F.sigmoid(model_scores - mixture_scores)
|
||||
else:
|
||||
model_scores, mixture_scores = None, None
|
||||
probability = self._compute_judge(model_data, mixture_data, context_length)
|
||||
|
||||
# Compute logprobs
|
||||
model_logprobs_model_data, ref_logprobs_model_data = self._compute_logprobs(model, model_data, context_length)
|
||||
|
||||
# Compute loss
|
||||
loss, score, kl_div = self._compute_losses(model_logprobs_model_data, ref_logprobs_model_data, probability)
|
||||
|
||||
# Log everything
|
||||
self._log_statistics(
|
||||
model_data,
|
||||
mixture_data,
|
||||
model_logprobs_model_data.detach(),
|
||||
ref_logprobs_model_data,
|
||||
probability,
|
||||
score.detach(),
|
||||
kl_div.detach(),
|
||||
context_length,
|
||||
model_scores,
|
||||
mixture_scores,
|
||||
)
|
||||
|
||||
if (
|
||||
self.args.torch_empty_cache_steps is not None
|
||||
and self.state.global_step % self.args.torch_empty_cache_steps == 0
|
||||
):
|
||||
empty_cache()
|
||||
|
||||
kwargs = {}
|
||||
# For LOMO optimizers you need to explicitly use the learning rate
|
||||
if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
|
||||
kwargs["learning_rate"] = self._get_learning_rate()
|
||||
|
||||
if self.args.n_gpu > 1:
|
||||
loss = loss.mean() # mean() to average on multi-gpu parallel training
|
||||
|
||||
if self.use_apex:
|
||||
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
|
||||
scaled_loss.backward()
|
||||
else:
|
||||
self.accelerator.backward(loss, **kwargs)
|
||||
|
||||
return loss.detach() / self.args.gradient_accumulation_steps
|
||||
|
||||
def create_model_card(
|
||||
self,
|
||||
model_name: Optional[str] = None,
|
||||
dataset_name: Optional[str] = None,
|
||||
tags: Union[str, list[str], None] = None,
|
||||
):
|
||||
"""
|
||||
Creates a draft of a model card using the information available to the `Trainer`.
|
||||
|
||||
Args:
|
||||
model_name (`str` or `None`, *optional*, defaults to `None`):
|
||||
Name of the model.
|
||||
dataset_name (`str` or `None`, *optional*, defaults to `None`):
|
||||
Name of the dataset used for training.
|
||||
tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
|
||||
Tags to be associated with the model card.
|
||||
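Example:

A minimal calling sketch (the model name, dataset name, and tag below are illustrative placeholders, not values defined in this file):

```python
trainer.create_model_card(model_name="my-nash-md-model", dataset_name="my-prompts", tags="nash-md")
```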
"""
|
||||
if not self.is_world_process_zero():
|
||||
return
|
||||
|
||||
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
|
||||
base_model = self.model.config._name_or_path
|
||||
else:
|
||||
base_model = None
|
||||
|
||||
tags = tags or []
|
||||
if isinstance(tags, str):
|
||||
tags = [tags]
|
||||
|
||||
if hasattr(self.model.config, "unsloth_version"):
|
||||
tags.append("unsloth")
|
||||
|
||||
citation = textwrap.dedent("""\
|
||||
@inproceedings{munos2024nash,
|
||||
title = {{Nash Learning from Human Feedback}},
|
||||
author = {R{\'{e}}mi Munos and Michal Valko and Daniele Calandriello and Mohammad Gheshlaghi Azar and Mark Rowland and Zhaohan Daniel Guo and Yunhao Tang and Matthieu Geist and Thomas Mesnard and C{\\^{o}}me Fiegel and Andrea Michi and Marco Selvi and Sertan Girgin and Nikola Momchev and Olivier Bachem and Daniel J. Mankowitz and Doina Precup and Bilal Piot},
|
||||
year = 2024,
|
||||
booktitle = {Forty-first International Conference on Machine Learning, {ICML} 2024, Vienna, Austria, July 21-27, 2024},
|
||||
publisher = {OpenReview.net},
|
||||
url = {https://openreview.net/forum?id=Y5AmNYiyCQ}
|
||||
}""")
|
||||
|
||||
model_card = generate_model_card(
|
||||
base_model=base_model,
|
||||
model_name=model_name,
|
||||
hub_model_id=self.hub_model_id,
|
||||
dataset_name=dataset_name,
|
||||
tags=tags,
|
||||
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
|
||||
comet_url=get_comet_experiment_url(),
|
||||
trainer_name="Nash-MD",
|
||||
trainer_citation=citation,
|
||||
paper_title="Nash Learning from Human Feedback",
|
||||
paper_id="2312.00886",
|
||||
)
|
||||
|
||||
model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
||||
class UnslothNashMDTrainer(_UnslothNashMDTrainer):
|
||||
"""
|
||||
|
||||
Initialize NashMDTrainer as a subclass of [`OnlineDPOTrainer`].
|
||||
|
||||
Args:
|
||||
model (`transformers.PreTrainedModel`):
|
||||
The model to train, preferably an `AutoModelForCausalLM`.
|
||||
ref_model (`PreTrainedModelWrapper`):
|
||||
Hugging Face transformer model with a causal language modeling head. Used for implicit reward computation and loss. If no
|
||||
reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized.
|
||||
reward_model (`transformers.PreTrainedModel`):
|
||||
The reward model to score completions with, preferably an `AutoModelForSequenceClassification`.
|
||||
judge (`BasePairwiseJudge`):
|
||||
The judge to use for pairwise comparison of model completions.
|
||||
args (`NashMDConfig`):
|
||||
The NashMD config arguments to use for training.
|
||||
data_collator (`transformers.DataCollator`):
|
||||
The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
|
||||
which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
|
||||
train_dataset (`datasets.Dataset`):
|
||||
The dataset to use for training.
|
||||
eval_dataset (`datasets.Dataset`):
|
||||
The dataset to use for evaluation.
|
||||
processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
|
||||
Processing class used to process the data. If provided, will be used to automatically process the inputs
|
||||
for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
|
||||
reuse the fine-tuned model.
|
||||
peft_config (`dict`):
|
||||
The peft config to use for training.
|
||||
compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
|
||||
The function to use to compute the metrics. Must take a `EvalPrediction` and return
|
||||
a dictionary string to metric values.
|
||||
callbacks (`list[transformers.TrainerCallback]`):
|
||||
The callbacks to use for training.
|
||||
optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
|
||||
The optimizer and scheduler to use for training.
|
||||
preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
|
||||
The function to use to preprocess the logits before computing the metrics.
|
||||
|
||||
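Example:

A minimal usage sketch (hedged: the checkpoint, judge choice, prompt dataset, and output directory below are illustrative assumptions, not values prescribed by this file):

```python
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import PairRMJudge

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
# Tiny placeholder prompt-only dataset; replace with a real prompt dataset.
train_dataset = Dataset.from_dict({"prompt": ["Which number is larger, 9.8 or 9.11?"]})

trainer = UnslothNashMDTrainer(
    model=model,                      # ref_model defaults to a copy of `model`
    judge=PairRMJudge(),              # any BasePairwiseJudge; a reward_model also works
    args=UnslothNashMDConfig(output_dir="nash-md-out", per_device_train_batch_size=1),
    train_dataset=train_dataset,
    processing_class=tokenizer,
)
trainer.train()
```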
"""
|
||||
def __init__(
|
||||
self,
|
||||
model = None,
|
||||
ref_model = None,
|
||||
reward_model = None,
|
||||
judge = None,
|
||||
args = None,
|
||||
data_collator = None,
|
||||
train_dataset = None,
|
||||
eval_dataset = None,
|
||||
processing_class = None,
|
||||
peft_config = None,
|
||||
compute_metrics = None,
|
||||
callbacks = None,
|
||||
preprocess_logits_for_metrics = None,
|
||||
**kwargs
|
||||
):
|
||||
if args is None: args = UnslothNashMDConfig()
|
||||
use_bf16 = getattr(args, 'bf16', False)
|
||||
use_fp16 = getattr(args, 'fp16', False)
|
||||
force_float32 = False
|
||||
if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
|
||||
print('Unsloth: Switching to float32 training since model cannot work with float16')
|
||||
force_float32 = True
|
||||
mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
|
||||
dtype = getattr(model.config, 'torch_dtype', None)
|
||||
if dtype is None: dtype = model.get_input_embeddings().dtype
|
||||
from unsloth_zoo.utils import _get_dtype
|
||||
dtype = _get_dtype(dtype)
|
||||
float16 = dtype == torch.float16
|
||||
if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
|
||||
if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
|
||||
if force_float32:
|
||||
args.fp16 = False
|
||||
args.bf16 = False
|
||||
os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
|
||||
elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
|
||||
args.fp16 = float16
|
||||
args.bf16 = not float16
|
||||
os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
|
||||
if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
|
||||
args.eval_strategy = 'steps'
|
||||
if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
|
||||
ga_steps = getattr(args, 'gradient_accumulation_steps', None)
|
||||
if ga_steps is not None and ga_steps > 1:
|
||||
from transformers import __version__ as transformers_version
|
||||
if Version(transformers_version) <= Version('4.45.2'):
|
||||
print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
|
||||
'`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
|
||||
if getattr(args, 'eval_strategy', 'no') != 'no':
|
||||
eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
|
||||
if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
|
||||
if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
|
||||
fp16_full_eval = getattr(args, 'fp16_full_eval', False)
|
||||
bf16_full_eval = getattr(args, 'bf16_full_eval', False)
|
||||
if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
|
||||
if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
|
||||
if force_float32:
|
||||
args.bf16_full_eval = False
|
||||
args.fp16_full_eval = False
|
||||
elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
|
||||
args.bf16_full_eval = True
|
||||
args.fp16_full_eval = False
|
||||
elif not bf16_full_eval and not fp16_full_eval:
|
||||
args.bf16_full_eval = args.bf16
|
||||
args.fp16_full_eval = args.fp16
|
||||
_output_logits = False
|
||||
if locals().get('compute_metrics', None) is not None: _output_logits = True
|
||||
if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
|
||||
if _output_logits:
|
||||
os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
|
||||
if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
|
||||
pass
|
||||
else:
|
||||
model_max_seq_length = getattr(model, 'max_seq_length', None)
|
||||
args_max_seq_length = getattr(args, 'max_seq_length', None)
|
||||
if args_max_seq_length is None and model_max_seq_length is not None:
|
||||
max_seq_length = model.max_seq_length
|
||||
if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
|
||||
if model is not None and hasattr(model, 'for_training'):
|
||||
model.for_training()
|
||||
if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
|
||||
if 'processing_class' in locals():
|
||||
if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
|
||||
if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
|
||||
__tokenizer = processing_class if 'processing_class' in locals() else tokenizer
|
||||
from unsloth_zoo.vision_utils import UnslothVisionDataCollator
|
||||
if not isinstance(data_collator, UnslothVisionDataCollator):
|
||||
if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
|
||||
data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer, mlm = False, mlm_probability = 0.0)
|
||||
elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
|
||||
data_collator = DataCollatorForSeq2Seq(__tokenizer)
|
||||
else:
|
||||
if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
|
||||
if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
|
||||
if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
|
||||
if not isinstance(data_collator, UnslothVisionDataCollator):
|
||||
if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
|
||||
if isinstance(data_collator, DataCollatorForSeq2Seq):
|
||||
data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
|
||||
else:
|
||||
data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False, mlm_probability = 0.0)
|
||||
other_metrics = []
|
||||
|
||||
from unsloth_zoo.logging_utils import PatchRLStatistics
|
||||
PatchRLStatistics('nash_md_trainer', other_metrics)
|
||||
|
||||
super().__init__(
|
||||
model = model,
|
||||
ref_model = ref_model,
|
||||
reward_model = reward_model,
|
||||
judge = judge,
|
||||
args = args,
|
||||
data_collator = data_collator,
|
||||
train_dataset = train_dataset,
|
||||
eval_dataset = eval_dataset,
|
||||
processing_class = processing_class,
|
||||
peft_config = peft_config,
|
||||
compute_metrics = compute_metrics,
|
||||
callbacks = callbacks,
|
||||
preprocess_logits_for_metrics = preprocess_logits_for_metrics,**kwargs)
|
||||
if hasattr(self, 'neftune_hook_handle'):
|
||||
self.neftune_hook_handle.remove()
|
||||
if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
|
||||
if getattr(args, 'neftune_noise_alpha', None) is not None:
|
||||
model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
|
||||
pass
|
||||
|
||||
pass
|
||||
1506
unsloth_compiled_cache/UnslothORPOTrainer.py
Normal file
File diff suppressed because it is too large
1276
unsloth_compiled_cache/UnslothOnlineDPOTrainer.py
Normal file
File diff suppressed because it is too large
1267
unsloth_compiled_cache/UnslothPPOTrainer.py
Normal file
File diff suppressed because it is too large
792
unsloth_compiled_cache/UnslothPRMTrainer.py
Normal file
@@ -0,0 +1,792 @@
|
||||
"""
|
||||
2025.6.1
|
||||
2025.6.2
|
||||
4.52.4
|
||||
0.18.2
|
||||
__UNSLOTH_VERSIONING__
|
||||
"""
|
||||
from torch import Tensor
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
from trl.trainer.prm_trainer import (BaseImageProcessor, Callable, DataCollator, DataCollatorForTokenClassification, Dataset, EvalPrediction, FeatureExtractionMixin, Optional, PRMConfig, PRMTrainer, PartialState, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, Trainer, TrainerCallback, Union, chain, compute_accuracy, disable_dropout_in_model, features, generate_model_card, inspect, is_peft_available, is_wandb_available, nn, os, prepare_model_for_kbit_training, textwrap, torch, warnings)
|
||||
|
||||
|
||||
import os
|
||||
from typing import *
|
||||
from dataclasses import dataclass, field
|
||||
from packaging.version import Version
|
||||
import torch
|
||||
import numpy as np
|
||||
from contextlib import nullcontext
|
||||
from torch.nn import functional as F
|
||||
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
|
||||
|
||||
torch_compile_options = {
|
||||
"epilogue_fusion" : True,
|
||||
"max_autotune" : False,
|
||||
"shape_padding" : True,
|
||||
"trace.enabled" : False,
|
||||
"triton.cudagraphs" : False,
|
||||
}
|
||||
|
||||
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
||||
def selective_log_softmax(logits, index):
|
||||
logits = logits.to(torch.float32)
|
||||
selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
|
||||
# loop to reduce peak mem consumption
|
||||
# logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
|
||||
logsumexp_values = torch.logsumexp(logits, dim = -1)
|
||||
per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
|
||||
return per_token_logps
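# Usage sketch (shapes are illustrative, not mandated by the callers in this file):
# given `logits` of shape (batch, seq_len, vocab_size) and token ids `index` of shape
# (batch, seq_len), this returns the per-token log-probabilities of the chosen tokens,
# shape (batch, seq_len), computed in float32 without materialising the full
# log_softmax tensor.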
|
||||
@dataclass
|
||||
class UnslothPRMConfig(PRMConfig):
|
||||
"""
|
||||
|
||||
Configuration class for the [`PRMTrainer`].
|
||||
|
||||
Using [`~transformers.HfArgumentParser`] we can turn this class into
|
||||
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
|
||||
command line.
|
||||
|
||||
Parameters:
|
||||
learning_rate (`float`, *optional*, defaults to `1e-5`):
|
||||
Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
|
||||
[`~transformers.TrainingArguments`].
|
||||
max_length (`int` or `None`, *optional*, defaults to `1024`):
|
||||
Maximum length of the sequences (prompt + completion) used for truncation.
|
||||
max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
|
||||
Maximum length of the prompt used for truncation.
|
||||
max_completion_length (`int` or `None`, *optional*, defaults to `None`):
|
||||
Maximum length of the completion used for truncation. The completion is the concatenation of the steps.
|
||||
disable_dropout (`bool`, *optional*, defaults to `True`):
|
||||
Whether to disable dropout in the model.
|
||||
step_separator (`str`, *optional*, defaults to `"\n"`):
|
||||
Separator used to separate each step of the reasoning process.
|
||||
train_on_last_step_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether to train only on the last step.
|
||||
dataset_num_proc (`int`, *optional*, defaults to `None`):
|
||||
Number of processes to use for processing the dataset.
|
||||
|
||||
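Example:

A minimal construction sketch (the values below are illustrative; `output_dir` is an assumption, not a default of this class):

```python
config = UnslothPRMConfig(
    output_dir="prm-out",
    max_length=1024,
    max_prompt_length=512,
    step_separator="\n",
    train_on_last_step_only=False,
)
```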
"""
|
||||
vllm_sampling_params: Optional[Any] = field(
|
||||
default = None,
|
||||
metadata = {'help': 'vLLM SamplingParams'},
|
||||
)
|
||||
unsloth_num_chunks : Optional[int] = field(
|
||||
default = -1,
|
||||
metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
|
||||
)
|
||||
def __init__(
|
||||
self,
|
||||
output_dir = None,
|
||||
overwrite_output_dir = None,
|
||||
do_train = False,
|
||||
do_eval = False,
|
||||
do_predict = False,
|
||||
eval_strategy = 'no',
|
||||
prediction_loss_only = False,
|
||||
per_device_train_batch_size = 4,
|
||||
per_device_eval_batch_size = 4,
|
||||
per_gpu_train_batch_size = None,
|
||||
per_gpu_eval_batch_size = None,
|
||||
gradient_accumulation_steps = 2,
|
||||
eval_accumulation_steps = 2,
|
||||
eval_delay = 0,
|
||||
torch_empty_cache_steps = 250,
|
||||
learning_rate = 5e-05,
|
||||
weight_decay = 0.01,
|
||||
adam_beta1 = 0.9,
|
||||
adam_beta2 = 0.999,
|
||||
adam_epsilon = 1e-08,
|
||||
max_grad_norm = 1.0,
|
||||
num_train_epochs = 3.0,
|
||||
max_steps = -1,
|
||||
lr_scheduler_type = 'linear',
|
||||
warmup_ratio = 0.1,
|
||||
warmup_steps = 0,
|
||||
log_level = 'passive',
|
||||
log_level_replica = 'warning',
|
||||
log_on_each_node = True,
|
||||
logging_dir = None,
|
||||
logging_strategy = 'steps',
|
||||
logging_first_step = False,
|
||||
logging_steps = 1,
|
||||
logging_nan_inf_filter = False,
|
||||
save_strategy = 'steps',
|
||||
save_steps = 500,
|
||||
save_total_limit = None,
|
||||
save_safetensors = True,
|
||||
save_on_each_node = False,
|
||||
save_only_model = False,
|
||||
restore_callback_states_from_checkpoint = False,
|
||||
no_cuda = False,
|
||||
use_cpu = False,
|
||||
use_mps_device = False,
|
||||
seed = 3407,
|
||||
data_seed = 3407,
|
||||
jit_mode_eval = False,
|
||||
use_ipex = False,
|
||||
bf16 = False,
|
||||
fp16 = False,
|
||||
fp16_opt_level = 'O1',
|
||||
half_precision_backend = 'auto',
|
||||
bf16_full_eval = False,
|
||||
fp16_full_eval = False,
|
||||
tf32 = None,
|
||||
local_rank = -1,
|
||||
ddp_backend = None,
|
||||
tpu_num_cores = None,
|
||||
tpu_metrics_debug = False,
|
||||
debug = '',
|
||||
dataloader_drop_last = False,
|
||||
eval_steps = None,
|
||||
dataloader_num_workers = 0,
|
||||
dataloader_prefetch_factor = None,
|
||||
past_index = -1,
|
||||
run_name = None,
|
||||
disable_tqdm = None,
|
||||
remove_unused_columns = True,
|
||||
label_names = None,
|
||||
load_best_model_at_end = False,
|
||||
metric_for_best_model = None,
|
||||
greater_is_better = None,
|
||||
ignore_data_skip = False,
|
||||
fsdp = '',
|
||||
fsdp_min_num_params = 0,
|
||||
fsdp_config = None,
|
||||
fsdp_transformer_layer_cls_to_wrap = None,
|
||||
accelerator_config = None,
|
||||
deepspeed = None,
|
||||
label_smoothing_factor = 0.0,
|
||||
optim = 'adamw_8bit',
|
||||
optim_args = None,
|
||||
adafactor = False,
|
||||
group_by_length = False,
|
||||
length_column_name = 'length',
|
||||
report_to = None,
|
||||
ddp_find_unused_parameters = None,
|
||||
ddp_bucket_cap_mb = None,
|
||||
ddp_broadcast_buffers = None,
|
||||
dataloader_pin_memory = True,
|
||||
dataloader_persistent_workers = False,
|
||||
skip_memory_metrics = True,
|
||||
use_legacy_prediction_loop = False,
|
||||
push_to_hub = False,
|
||||
resume_from_checkpoint = None,
|
||||
hub_model_id = None,
|
||||
hub_strategy = 'every_save',
|
||||
hub_token = None,
|
||||
hub_private_repo = None,
|
||||
hub_always_push = False,
|
||||
gradient_checkpointing = False,
|
||||
gradient_checkpointing_kwargs = None,
|
||||
include_inputs_for_metrics = False,
|
||||
eval_do_concat_batches = True,
|
||||
fp16_backend = 'auto',
|
||||
push_to_hub_model_id = None,
|
||||
push_to_hub_organization = None,
|
||||
push_to_hub_token = None,
|
||||
mp_parameters = '',
|
||||
auto_find_batch_size = False,
|
||||
full_determinism = False,
|
||||
torchdynamo = None,
|
||||
ray_scope = 'last',
|
||||
ddp_timeout = 1800,
|
||||
torch_compile = False,
|
||||
torch_compile_backend = None,
|
||||
torch_compile_mode = None,
|
||||
include_tokens_per_second = False,
|
||||
include_num_input_tokens_seen = False,
|
||||
neftune_noise_alpha = None,
|
||||
optim_target_modules = None,
|
||||
batch_eval_metrics = False,
|
||||
eval_on_start = False,
|
||||
use_liger_kernel = False,
|
||||
eval_use_gather_object = False,
|
||||
average_tokens_across_devices = False,
|
||||
max_length = 1024,
|
||||
max_prompt_length = 512,
|
||||
max_completion_length = None,
|
||||
disable_dropout = True,
|
||||
step_separator = '\n',
|
||||
train_on_last_step_only = False,
|
||||
dataset_num_proc = None,
|
||||
vllm_sampling_params = None,
|
||||
unsloth_num_chunks = -1,
|
||||
**kwargs,
|
||||
):
|
||||
if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
|
||||
if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too large (> 1)! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
|
||||
if output_dir is None and save_strategy == 'steps' and save_steps == 500:
|
||||
output_dir = 'unsloth_training_checkpoints'
|
||||
save_strategy = 'no'
|
||||
if dataset_num_proc is None:
|
||||
from multiprocessing import cpu_count
|
||||
dataset_num_proc = cpu_count()
|
||||
|
||||
super().__init__(
|
||||
output_dir = output_dir,
|
||||
overwrite_output_dir = overwrite_output_dir,
|
||||
do_train = do_train,
|
||||
do_eval = do_eval,
|
||||
do_predict = do_predict,
|
||||
eval_strategy = eval_strategy,
|
||||
prediction_loss_only = prediction_loss_only,
|
||||
per_device_train_batch_size = per_device_train_batch_size,
|
||||
per_device_eval_batch_size = per_device_eval_batch_size,
|
||||
per_gpu_train_batch_size = per_gpu_train_batch_size,
|
||||
per_gpu_eval_batch_size = per_gpu_eval_batch_size,
|
||||
gradient_accumulation_steps = gradient_accumulation_steps,
|
||||
eval_accumulation_steps = eval_accumulation_steps,
|
||||
eval_delay = eval_delay,
|
||||
torch_empty_cache_steps = torch_empty_cache_steps,
|
||||
learning_rate = learning_rate,
|
||||
weight_decay = weight_decay,
|
||||
adam_beta1 = adam_beta1,
|
||||
adam_beta2 = adam_beta2,
|
||||
adam_epsilon = adam_epsilon,
|
||||
max_grad_norm = max_grad_norm,
|
||||
num_train_epochs = num_train_epochs,
|
||||
max_steps = max_steps,
|
||||
lr_scheduler_type = lr_scheduler_type,
|
||||
warmup_ratio = warmup_ratio,
|
||||
warmup_steps = warmup_steps,
|
||||
log_level = log_level,
|
||||
log_level_replica = log_level_replica,
|
||||
log_on_each_node = log_on_each_node,
|
||||
logging_dir = logging_dir,
|
||||
logging_strategy = logging_strategy,
|
||||
logging_first_step = logging_first_step,
|
||||
logging_steps = logging_steps,
|
||||
logging_nan_inf_filter = logging_nan_inf_filter,
|
||||
save_strategy = save_strategy,
|
||||
save_steps = save_steps,
|
||||
save_total_limit = save_total_limit,
|
||||
save_safetensors = save_safetensors,
|
||||
save_on_each_node = save_on_each_node,
|
||||
save_only_model = save_only_model,
|
||||
restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
|
||||
no_cuda = no_cuda,
|
||||
use_cpu = use_cpu,
|
||||
use_mps_device = use_mps_device,
|
||||
seed = seed,
|
||||
data_seed = data_seed,
|
||||
jit_mode_eval = jit_mode_eval,
|
||||
use_ipex = use_ipex,
|
||||
bf16 = bf16,
|
||||
fp16 = fp16,
|
||||
fp16_opt_level = fp16_opt_level,
|
||||
half_precision_backend = half_precision_backend,
|
||||
bf16_full_eval = bf16_full_eval,
|
||||
fp16_full_eval = fp16_full_eval,
|
||||
tf32 = tf32,
|
||||
local_rank = local_rank,
|
||||
ddp_backend = ddp_backend,
|
||||
tpu_num_cores = tpu_num_cores,
|
||||
tpu_metrics_debug = tpu_metrics_debug,
|
||||
debug = debug,
|
||||
dataloader_drop_last = dataloader_drop_last,
|
||||
eval_steps = eval_steps,
|
||||
dataloader_num_workers = dataloader_num_workers,
|
||||
dataloader_prefetch_factor = dataloader_prefetch_factor,
|
||||
past_index = past_index,
|
||||
run_name = run_name,
|
||||
disable_tqdm = disable_tqdm,
|
||||
remove_unused_columns = remove_unused_columns,
|
||||
label_names = label_names,
|
||||
load_best_model_at_end = load_best_model_at_end,
|
||||
metric_for_best_model = metric_for_best_model,
|
||||
greater_is_better = greater_is_better,
|
||||
ignore_data_skip = ignore_data_skip,
|
||||
fsdp = fsdp,
|
||||
fsdp_min_num_params = fsdp_min_num_params,
|
||||
fsdp_config = fsdp_config,
|
||||
fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
|
||||
accelerator_config = accelerator_config,
|
||||
deepspeed = deepspeed,
|
||||
label_smoothing_factor = label_smoothing_factor,
|
||||
optim = optim,
|
||||
optim_args = optim_args,
|
||||
adafactor = adafactor,
|
||||
group_by_length = group_by_length,
|
||||
length_column_name = length_column_name,
|
||||
report_to = report_to,
|
||||
ddp_find_unused_parameters = ddp_find_unused_parameters,
|
||||
ddp_bucket_cap_mb = ddp_bucket_cap_mb,
|
||||
ddp_broadcast_buffers = ddp_broadcast_buffers,
|
||||
dataloader_pin_memory = dataloader_pin_memory,
|
||||
dataloader_persistent_workers = dataloader_persistent_workers,
|
||||
skip_memory_metrics = skip_memory_metrics,
|
||||
use_legacy_prediction_loop = use_legacy_prediction_loop,
|
||||
push_to_hub = push_to_hub,
|
||||
resume_from_checkpoint = resume_from_checkpoint,
|
||||
hub_model_id = hub_model_id,
|
||||
hub_strategy = hub_strategy,
|
||||
hub_token = hub_token,
|
||||
hub_private_repo = hub_private_repo,
|
||||
hub_always_push = hub_always_push,
|
||||
gradient_checkpointing = gradient_checkpointing,
|
||||
gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
|
||||
include_inputs_for_metrics = include_inputs_for_metrics,
|
||||
eval_do_concat_batches = eval_do_concat_batches,
|
||||
fp16_backend = fp16_backend,
|
||||
push_to_hub_model_id = push_to_hub_model_id,
|
||||
push_to_hub_organization = push_to_hub_organization,
|
||||
push_to_hub_token = push_to_hub_token,
|
||||
mp_parameters = mp_parameters,
|
||||
auto_find_batch_size = auto_find_batch_size,
|
||||
full_determinism = full_determinism,
|
||||
torchdynamo = torchdynamo,
|
||||
ray_scope = ray_scope,
|
||||
ddp_timeout = ddp_timeout,
|
||||
torch_compile = torch_compile,
|
||||
torch_compile_backend = torch_compile_backend,
|
||||
torch_compile_mode = torch_compile_mode,
|
||||
include_tokens_per_second = include_tokens_per_second,
|
||||
include_num_input_tokens_seen = include_num_input_tokens_seen,
|
||||
neftune_noise_alpha = neftune_noise_alpha,
|
||||
optim_target_modules = optim_target_modules,
|
||||
batch_eval_metrics = batch_eval_metrics,
|
||||
eval_on_start = eval_on_start,
|
||||
use_liger_kernel = use_liger_kernel,
|
||||
eval_use_gather_object = eval_use_gather_object,
|
||||
average_tokens_across_devices = average_tokens_across_devices,
|
||||
max_length = max_length,
|
||||
max_prompt_length = max_prompt_length,
|
||||
max_completion_length = max_completion_length,
|
||||
disable_dropout = disable_dropout,
|
||||
step_separator = step_separator,
|
||||
train_on_last_step_only = train_on_last_step_only,
|
||||
dataset_num_proc = dataset_num_proc,**kwargs)
|
||||
self.vllm_sampling_params = vllm_sampling_params
|
||||
self.unsloth_num_chunks = unsloth_num_chunks
|
||||
pass
|
||||
|
||||
class _UnslothPRMTrainer(Trainer):
|
||||
""""""
|
||||
|
||||
_tag_names = ["trl", "prm"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Optional[Union[PreTrainedModel, nn.Module]] = None,
|
||||
args: Optional[PRMConfig] = None,
|
||||
data_collator: Optional[DataCollator] = None,
|
||||
train_dataset: Optional[Dataset] = None,
|
||||
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
|
||||
processing_class: Optional[
|
||||
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
|
||||
] = None,
|
||||
model_init: Optional[Callable[[], PreTrainedModel]] = None,
|
||||
compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
|
||||
callbacks: Optional[list[TrainerCallback]] = None,
|
||||
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (
|
||||
None,
|
||||
None,
|
||||
),
|
||||
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
||||
peft_config: Optional[dict] = None,
|
||||
):
|
||||
if not is_peft_available() and peft_config is not None:
|
||||
raise ValueError(
|
||||
"PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
|
||||
)
|
||||
elif is_peft_available() and peft_config is not None:
|
||||
if not isinstance(model, PeftModel):
|
||||
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_quantized", False):
|
||||
_supports_gc_kwargs = "gradient_checkpointing_kwargs" in list(
|
||||
inspect.signature(prepare_model_for_kbit_training).parameters
|
||||
)
|
||||
|
||||
prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
|
||||
|
||||
if not _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None:
|
||||
warnings.warn(
|
||||
"You passed `gradient_checkpointing_kwargs` in the trainer's kwargs, but your peft version does not support it. "
|
||||
"please update to the latest version of peft to use `gradient_checkpointing_kwargs`."
|
||||
)
|
||||
elif _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None:
|
||||
prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
|
||||
|
||||
model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
|
||||
|
||||
model = model
|
||||
|
||||
# Disable dropout in the model
|
||||
if args.disable_dropout:
|
||||
disable_dropout_in_model(model)
|
||||
|
||||
if compute_metrics is None:
|
||||
compute_metrics = compute_accuracy
|
||||
|
||||
if data_collator is None:
|
||||
if processing_class is None:
|
||||
raise ValueError(
|
||||
"A processing_class must be specified when using the default DataCollatorForTokenClassification"
|
||||
)
|
||||
data_collator = DataCollatorForTokenClassification(processing_class, max_length=args.max_length)
|
||||
|
||||
if "input_ids" not in train_dataset.column_names:
|
||||
with PartialState().main_process_first():
|
||||
fn_kwargs = {
|
||||
"tokenizer": processing_class,
|
||||
"step_separator": args.step_separator,
|
||||
"max_length": args.max_length,
|
||||
"max_prompt_length": args.max_prompt_length,
|
||||
"max_completion_length": args.max_completion_length,
|
||||
"train_on_last_step_only": args.train_on_last_step_only,
|
||||
}
|
||||
train_fn_kwargs = {**fn_kwargs, "is_eval": False}
|
||||
train_dataset = train_dataset.map(
|
||||
self.tokenize_row,
|
||||
fn_kwargs=train_fn_kwargs,
|
||||
num_proc=args.dataset_num_proc,
|
||||
remove_columns=train_dataset.features,
|
||||
desc="Tokenizing train dataset",
|
||||
features=features.Features( # needed to avoid map to cast labels to bool
|
||||
{
|
||||
"labels": features.Sequence(features.Value("int64")),
|
||||
"input_ids": features.Sequence(features.Value("int64")),
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
eval_fn_kwargs = {**fn_kwargs, "is_eval": True}
|
||||
if eval_dataset is not None:
|
||||
eval_dataset = eval_dataset.map(
|
||||
self.tokenize_row,
|
||||
fn_kwargs=eval_fn_kwargs,
|
||||
num_proc=args.dataset_num_proc,
|
||||
remove_columns=eval_dataset.features,
|
||||
desc="Tokenizing eval dataset",
|
||||
features=features.Features( # needed to avoid map to cast labels to bool
|
||||
{
|
||||
"labels": features.Sequence(features.Value("int64")),
|
||||
"input_ids": features.Sequence(features.Value("int64")),
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
model=model,
|
||||
args=args,
|
||||
data_collator=data_collator,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
processing_class=processing_class,
|
||||
model_init=model_init,
|
||||
compute_metrics=compute_metrics,
|
||||
callbacks=callbacks,
|
||||
optimizers=optimizers,
|
||||
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
|
||||
)
|
||||
|
||||
# Add tags for models that have been loaded with the correct transformers version
|
||||
if hasattr(self.model, "add_model_tags"):
|
||||
self.model.add_model_tags(self._tag_names)
|
||||
|
||||
@staticmethod
|
||||
def tokenize_row(
|
||||
features,
|
||||
tokenizer,
|
||||
step_separator,
|
||||
max_length,
|
||||
max_prompt_length,
|
||||
max_completion_length,
|
||||
train_on_last_step_only,
|
||||
is_eval,
|
||||
):
|
||||
r"""
|
||||
Tokenize a row of the dataset.
|
||||
|
||||
Args:
|
||||
features (`dict[str, str]`):
|
||||
Row of the dataset, should contain the keys `"prompt"`, `"completions"`, and `"labels"`.
|
||||
tokenizer (`PreTrainedTokenizerBase`):
|
||||
Tokenizer used to process the data.
|
||||
step_separator (`str`):
|
||||
Separator between steps in the completion.
|
||||
max_length (`int` or `None`):
|
||||
Maximum length of the sequences (prompt + completion). If `None`, the sequences are not truncated.
|
||||
max_prompt_length (`int` or `None`):
|
||||
Maximum length of the prompt. If `None`, the prompt is not truncated.
|
||||
max_completion_length (`int` or `None`):
|
||||
Maximum length of the completion sequences. If `None`, the completion sequences are not truncated.
|
||||
train_on_last_step_only (`bool`):
|
||||
Whether to train only on the last step. If `True`, the labels are `-100` for all tokens except the last
|
||||
token of the completion.
|
||||
is_eval (`bool`):
|
||||
Whether the function is used to tokenize samples from a training or an evaluation dataset. Used only if `train_on_last_step_only` is set to `True`.
|
||||
|
||||
Returns:
|
||||
`dict[str, list[int]]`:
|
||||
Tokenized sequences with the keys `"input_ids"` and `"labels"`.
|
||||
|
||||
Example:
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
|
||||
>>> features = {"prompt": "Which number is larger, 9.8 or 9.11?",
|
||||
... "completions": ["11 is greater than 8.",
|
||||
... "Hence, 9.11 > 9.8."],
|
||||
... "labels": [True, False]}
|
||||
>>> PRMTrainer.tokenize_row(features, tokenizer, "\n", max_length=None, max_prompt_length=None, max_completion_length=None, train_on_last_step_only=False, is_eval=False)
|
||||
{'input_ids': [23085, 1372, 374, 8131, 11, 220, 24, 13, 23, 476, 220, 24, 13, 16, 16, 30, 16, 16, 374, 7046, 1091, 220, 23, 13, 198, 39, 763, 11, 220, 24, 13, 16, 16, 861, 220, 24, 13, 23, 13, 198],
|
||||
'labels': [-100, -100, -100, -100, -100, -100, -100, -100, 1, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 0]}
|
||||
```
|
||||
"""
|
||||
# Tokenize the prompt and completions
|
||||
prompt_ids = tokenizer(features["prompt"], add_special_tokens=False)["input_ids"]
|
||||
completions_ids = [
|
||||
tokenizer(completion, add_special_tokens=False)["input_ids"] for completion in features["completions"]
|
||||
]
|
||||
if train_on_last_step_only and not is_eval:
|
||||
labels = [-100] * (len(features["labels"]) - 1) + [int(features["labels"][-1])]
|
||||
else:
|
||||
labels = [int(label) for label in features["labels"]]
|
||||
|
||||
# Get the ID of the separator token and add it to the completions
|
||||
separator_ids = tokenizer.encode(step_separator, add_special_tokens=False)
|
||||
completions_ids = [completion + separator_ids for completion in completions_ids]
|
||||
|
||||
# Create the label
|
||||
labels = [[-100] * (len(completion) - 1) + [label] for completion, label in zip(completions_ids, labels)]
|
||||
|
||||
# Join the completions and labels steps
|
||||
completion_ids = list(chain(*completions_ids))
|
||||
labels = list(chain(*labels))
|
||||
|
||||
if tokenizer.bos_token_id is not None:
|
||||
prompt_ids = [tokenizer.bos_token_id] + prompt_ids
|
||||
|
||||
# Truncate prompt and completion sequences
|
||||
if max_prompt_length is not None:
|
||||
prompt_ids = prompt_ids[-max_prompt_length:]
|
||||
if max_completion_length is not None:
|
||||
completion_ids = completion_ids[:max_completion_length]
|
||||
labels = labels[:max_completion_length]
|
||||
|
||||
input_ids = prompt_ids + completion_ids
|
||||
labels = [-100] * len(prompt_ids) + labels
|
||||
|
||||
if max_length is not None:
|
||||
input_ids = input_ids[:max_length]
|
||||
labels = labels[:max_length]
|
||||
|
||||
return {"input_ids": input_ids, "labels": labels}
|
||||
|
||||
def create_model_card(
|
||||
self,
|
||||
model_name: Optional[str] = None,
|
||||
dataset_name: Optional[str] = None,
|
||||
tags: Union[str, list[str], None] = None,
|
||||
):
|
||||
"""
|
||||
Creates a draft of a model card using the information available to the `Trainer`.
|
||||
|
||||
Args:
|
||||
model_name (`str` or `None`, *optional*, defaults to `None`):
|
||||
Name of the model.
|
||||
dataset_name (`str` or `None`, *optional*, defaults to `None`):
|
||||
Name of the dataset used for training.
|
||||
tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
|
||||
Tags to be associated with the model card.
|
||||
"""
|
||||
if not self.is_world_process_zero():
|
||||
return
|
||||
|
||||
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
|
||||
base_model = self.model.config._name_or_path
|
||||
else:
|
||||
base_model = None
|
||||
|
||||
tags = tags or []
|
||||
if isinstance(tags, str):
|
||||
tags = [tags]
|
||||
|
||||
if hasattr(self.model.config, "unsloth_version"):
|
||||
tags.append("unsloth")
|
||||
|
||||
citation = textwrap.dedent("""\
|
||||
@article{uesato2022solving,
|
||||
title = {{Solving Math Word Problems With Process- and Outcome-Based Feedback}},
|
||||
author = {Uesato, Jonathan and Kushman, Nate and Kumar, Ramana and Song, Francis and Siegel, Noah and Wang, Lisa and Creswell, Antonia and Irving, Geoffrey and Higgins, Irina},
|
||||
year = 2022,
|
||||
journal = {arXiv preprint arXiv:2211.14275}
|
||||
}""")
|
||||
|
||||
model_card = generate_model_card(
|
||||
base_model=base_model,
|
||||
model_name=model_name,
|
||||
hub_model_id=self.hub_model_id,
|
||||
dataset_name=dataset_name,
|
||||
tags=tags,
|
||||
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
|
||||
trainer_name="PRM",
|
||||
trainer_citation=citation,
|
||||
paper_title="Solving math word problems with process-and outcome-based feedback",
|
||||
)
|
||||
|
||||
model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
||||
class UnslothPRMTrainer(_UnslothPRMTrainer):
|
||||
"""
|
||||
|
||||
Initialize PRMTrainer.
|
||||
|
||||
Args:
|
||||
model (`transformers.PreTrainedModel`):
|
||||
The model to train, preferably an `AutoModelForTokenClassification`.
|
||||
args (`PRMConfig`):
|
||||
The arguments to use for training.
|
||||
data_collator (`transformers.DataCollator`):
|
||||
The data collator to use for training. If None is specified, the default data collator (`DataCollatorForTokenClassification`) will be used
|
||||
which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
|
||||
train_dataset (`datasets.Dataset`):
|
||||
The dataset to use for training.
|
||||
eval_dataset (`datasets.Dataset`):
|
||||
The dataset to use for evaluation.
|
||||
processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
|
||||
Processing class used to process the data. If provided, will be used to automatically process the inputs
|
||||
for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
|
||||
reuse the fine-tuned model.
|
||||
model_init (`Callable[[], transformers.PreTrainedModel]`):
|
||||
The model initializer to use for training. If None is specified, the default model initializer will be used.
|
||||
compute_metrics (`Callable[[transformers.EvalPrediction], dict]`, *optional*, defaults to `compute_accuracy`):
|
||||
The metrics to use for evaluation. If no metrics are specified, the default metric (`compute_accuracy`) will be used.
|
||||
callbacks (`list[transformers.TrainerCallback]`):
|
||||
The callbacks to use for training.
|
||||
optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
|
||||
The optimizer and scheduler to use for training.
|
||||
preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
|
||||
The function to use to preprocess the logits before computing the metrics.
|
||||
peft_config (`dict`, defaults to `None`):
|
||||
The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
|
||||
|
||||
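Example:

A minimal usage sketch reusing the row format shown in `tokenize_row` above (the checkpoint, labels, and output directory are illustrative placeholders):

```python
from datasets import Dataset
from transformers import AutoModelForTokenClassification, AutoTokenizer

model = AutoModelForTokenClassification.from_pretrained("Qwen/Qwen2.5-0.5B", num_labels=2)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
train_dataset = Dataset.from_dict({
    "prompt": ["Which number is larger, 9.8 or 9.11?"],
    "completions": [["11 is greater than 8.", "Hence, 9.11 > 9.8."]],
    "labels": [[True, False]],
})

trainer = UnslothPRMTrainer(
    model=model,
    args=UnslothPRMConfig(output_dir="prm-out", per_device_train_batch_size=1),
    train_dataset=train_dataset,
    processing_class=tokenizer,
)
trainer.train()
```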
"""
|
||||
def __init__(
|
||||
self,
|
||||
model = None,
|
||||
args = None,
|
||||
data_collator = None,
|
||||
train_dataset = None,
|
||||
eval_dataset = None,
|
||||
processing_class = None,
|
||||
model_init = None,
|
||||
compute_metrics = None,
|
||||
callbacks = None,
|
||||
preprocess_logits_for_metrics = None,
|
||||
peft_config = None,
|
||||
**kwargs
|
||||
):
|
||||
if args is None: args = UnslothPRMConfig()
|
||||
use_bf16 = getattr(args, 'bf16', False)
|
||||
use_fp16 = getattr(args, 'fp16', False)
|
||||
force_float32 = False
|
||||
if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
|
||||
print('Unsloth: Switching to float32 training since model cannot work with float16')
|
||||
force_float32 = True
|
||||
mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
|
||||
dtype = getattr(model.config, 'torch_dtype', None)
|
||||
if dtype is None: dtype = model.get_input_embeddings().dtype
|
||||
from unsloth_zoo.utils import _get_dtype
|
||||
dtype = _get_dtype(dtype)
|
||||
float16 = dtype == torch.float16
|
||||
if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
|
||||
if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
|
||||
if force_float32:
|
||||
args.fp16 = False
|
||||
args.bf16 = False
|
||||
os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
|
||||
elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
|
||||
args.fp16 = float16
|
||||
args.bf16 = not float16
|
||||
os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
|
||||
if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
|
||||
args.eval_strategy = 'steps'
|
||||
if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
|
||||
ga_steps = getattr(args, 'gradient_accumulation_steps', None)
|
||||
if ga_steps is not None and ga_steps > 1:
|
||||
from transformers import __version__ as transformers_version
|
||||
if Version(transformers_version) <= Version('4.45.2'):
|
||||
print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
|
||||
'`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
|
||||
if getattr(args, 'eval_strategy', 'no') != 'no':
|
||||
eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
|
||||
if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
|
||||
if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
|
||||
fp16_full_eval = getattr(args, 'fp16_full_eval', False)
|
||||
bf16_full_eval = getattr(args, 'bf16_full_eval', False)
|
||||
if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
|
||||
if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
|
||||
if force_float32:
|
||||
args.bf16_full_eval = False
|
||||
args.fp16_full_eval = False
|
||||
elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
|
||||
args.bf16_full_eval = True
|
||||
args.fp16_full_eval = False
|
||||
elif not bf16_full_eval and not fp16_full_eval:
|
||||
args.bf16_full_eval = args.bf16
|
||||
args.fp16_full_eval = args.fp16
|
||||
_output_logits = False
|
||||
if locals().get('compute_metrics', None) is not None: _output_logits = True
|
||||
if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
|
||||
if _output_logits:
|
||||
os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
|
||||
if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
|
||||
pass
|
||||
else:
|
||||
model_max_seq_length = getattr(model, 'max_seq_length', None)
|
||||
args_max_seq_length = getattr(args, 'max_seq_length', None)
|
||||
if args_max_seq_length is None and model_max_seq_length is not None:
|
||||
max_seq_length = model.max_seq_length
|
||||
if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
|
||||
if model is not None and hasattr(model, 'for_training'):
|
||||
model.for_training()
|
||||
if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
|
||||
if 'processing_class' in locals():
|
||||
if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
|
||||
if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
|
||||
__tokenizer = processing_class if 'processing_class' in locals() else tokenizer
|
||||
from unsloth_zoo.vision_utils import UnslothVisionDataCollator
|
||||
if not isinstance(data_collator, UnslothVisionDataCollator):
|
||||
if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
|
||||
data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer, mlm = False, mlm_probability = 0.0)
|
||||
elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
|
||||
data_collator = DataCollatorForSeq2Seq(__tokenizer)
|
||||
else:
|
||||
if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
|
||||
if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
|
||||
if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
|
||||
if not isinstance(data_collator, UnslothVisionDataCollator):
|
||||
if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
|
||||
if isinstance(data_collator, DataCollatorForSeq2Seq):
|
||||
data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
|
||||
else:
|
||||
data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False, mlm_probability = 0.0)
|
||||
other_metrics = []
|
||||
|
||||
from unsloth_zoo.logging_utils import PatchRLStatistics
|
||||
PatchRLStatistics('prm_trainer', other_metrics)
|
||||
|
||||
super().__init__(
|
||||
model = model,
|
||||
args = args,
|
||||
data_collator = data_collator,
|
||||
train_dataset = train_dataset,
|
||||
eval_dataset = eval_dataset,
|
||||
processing_class = processing_class,
|
||||
model_init = model_init,
|
||||
compute_metrics = compute_metrics,
|
||||
callbacks = callbacks,
|
||||
preprocess_logits_for_metrics = preprocess_logits_for_metrics,
|
||||
peft_config = peft_config,**kwargs)
|
||||
if hasattr(self, 'neftune_hook_handle'):
|
||||
self.neftune_hook_handle.remove()
|
||||
if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
|
||||
if getattr(args, 'neftune_noise_alpha', None) is not None:
|
||||
model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
|
||||
pass
|
||||
|
||||
pass
|
||||
1126
unsloth_compiled_cache/UnslothRLOOTrainer.py
Normal file
File diff suppressed because it is too large
812
unsloth_compiled_cache/UnslothRewardTrainer.py
Normal file
@@ -0,0 +1,812 @@
|
||||
"""
|
||||
2025.6.1
|
||||
2025.6.2
|
||||
4.52.4
|
||||
0.18.2
|
||||
__UNSLOTH_VERSIONING__
|
||||
"""
|
||||
from torch import Tensor
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
from trl.trainer.reward_trainer import (Any, BaseImageProcessor, Callable, DataCollator, Dataset, EvalPrediction, FeatureExtractionMixin, FrozenInstanceError, Optional, PartialState, PeftModel, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, RewardConfig, RewardDataCollatorWithPadding, RewardTrainer, Trainer, TrainerCallback, Union, _tokenize, compute_accuracy, decode_and_strip_padding, defaultdict, disable_dropout_in_model, gather_object, generate_model_card, get_comet_experiment_url, inspect, is_peft_available, is_rich_available, is_wandb_available, log_table_to_comet_experiment, maybe_apply_chat_template, nested_detach, nn, os, pd, prepare_model_for_kbit_training, print_rich_table, replace, torch, warnings)
|
||||
|
||||
|
||||
import os
|
||||
from typing import *
|
||||
from dataclasses import dataclass, field
|
||||
from packaging.version import Version
|
||||
import torch
|
||||
import numpy as np
|
||||
from contextlib import nullcontext
|
||||
from torch.nn import functional as F
|
||||
from transformers import DataCollatorForSeq2Seq, DataCollatorForLanguageModeling as TransformersDataCollatorForLanguageModeling
|
||||
|
||||
torch_compile_options = {
|
||||
"epilogue_fusion" : True,
|
||||
"max_autotune" : False,
|
||||
"shape_padding" : True,
|
||||
"trace.enabled" : False,
|
||||
"triton.cudagraphs" : False,
|
||||
}
|
||||
|
||||
@torch.compile(dynamic = True, fullgraph = True, options = torch_compile_options,)
|
||||
def selective_log_softmax(logits, index):
|
||||
logits = logits.to(torch.float32)
|
||||
selected_logits = torch.gather(logits, dim = -1, index = index.unsqueeze(-1)).squeeze(-1)
|
||||
# loop to reduce peak mem consumption
|
||||
# logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
|
||||
logsumexp_values = torch.logsumexp(logits, dim = -1)
|
||||
per_token_logps = selected_logits - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
|
||||
return per_token_logps
|
||||
@dataclass
|
||||
class UnslothRewardConfig(RewardConfig):
|
||||
"""
|
||||
|
||||
Configuration class for the [`RewardTrainer`].
|
||||
|
||||
Using [`~transformers.HfArgumentParser`] we can turn this class into
|
||||
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
|
||||
command line.
|
||||
|
||||
Parameters:
|
||||
max_length (`int` or `None`, *optional*, defaults to `1024`):
|
||||
Maximum length of the sequences (prompt + completion) in the batch, filters out entries that exceed the
|
||||
limit. This argument is required if you want to use the default data collator.
|
||||
disable_dropout (`bool`, *optional*, defaults to `True`):
|
||||
Whether to disable dropout in the model.
|
||||
dataset_num_proc (`int`, *optional*, defaults to `None`):
|
||||
Number of processes to use for processing the dataset.
|
||||
center_rewards_coefficient (`float`, *optional*, defaults to `None`):
|
||||
Coefficient to incentivize the reward model to output mean-zero rewards (proposed by
|
||||
https://huggingface.co/papers/2312.09244, Eq. 2). Recommended value: `0.01`.
|
||||
remove_unused_columns (`bool`, *optional*, defaults to `False`):
|
||||
Whether to remove the columns that are not used by the model's forward pass. Can be `True` only if
|
||||
the dataset is pretokenized.
|
||||
|
||||
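Example:

A minimal construction sketch (values are illustrative; `center_rewards_coefficient=0.01` follows the recommendation above):

```python
config = UnslothRewardConfig(
    output_dir="reward-out",
    max_length=1024,
    disable_dropout=True,
    center_rewards_coefficient=0.01,
)
```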
"""
|
||||
vllm_sampling_params: Optional[Any] = field(
|
||||
default = None,
|
||||
metadata = {'help': 'vLLM SamplingParams'},
|
||||
)
|
||||
unsloth_num_chunks : Optional[int] = field(
|
||||
default = -1,
|
||||
metadata = {'help': 'Chunk size to reduce memory usage. -1 is most efficient.'},
|
||||
)
|
||||
def __init__(
|
||||
self,
|
||||
output_dir = None,
|
||||
overwrite_output_dir = None,
|
||||
do_train = False,
|
||||
do_eval = False,
|
||||
do_predict = False,
|
||||
eval_strategy = 'no',
|
||||
prediction_loss_only = False,
|
||||
per_device_train_batch_size = 4,
|
||||
per_device_eval_batch_size = 4,
|
||||
per_gpu_train_batch_size = None,
|
||||
per_gpu_eval_batch_size = None,
|
||||
gradient_accumulation_steps = 2,
|
||||
eval_accumulation_steps = 2,
|
||||
eval_delay = 0,
|
||||
torch_empty_cache_steps = 250,
|
||||
learning_rate = 5e-05,
|
||||
weight_decay = 0.01,
|
||||
adam_beta1 = 0.9,
|
||||
adam_beta2 = 0.999,
|
||||
adam_epsilon = 1e-08,
|
||||
max_grad_norm = 1.0,
|
||||
num_train_epochs = 3.0,
|
||||
max_steps = -1,
|
||||
lr_scheduler_type = 'linear',
|
||||
warmup_ratio = 0.1,
|
||||
warmup_steps = 0,
|
||||
log_level = 'passive',
|
||||
log_level_replica = 'warning',
|
||||
log_on_each_node = True,
|
||||
logging_dir = None,
|
||||
logging_strategy = 'steps',
|
||||
logging_first_step = False,
|
||||
logging_steps = 1,
|
||||
logging_nan_inf_filter = False,
|
||||
save_strategy = 'steps',
|
||||
save_steps = 500,
|
||||
save_total_limit = None,
|
||||
save_safetensors = True,
|
||||
save_on_each_node = False,
|
||||
save_only_model = False,
|
||||
restore_callback_states_from_checkpoint = False,
|
||||
no_cuda = False,
|
||||
use_cpu = False,
|
||||
use_mps_device = False,
|
||||
seed = 3407,
|
||||
data_seed = 3407,
|
||||
jit_mode_eval = False,
|
||||
use_ipex = False,
|
||||
bf16 = False,
|
||||
fp16 = False,
|
||||
fp16_opt_level = 'O1',
|
||||
half_precision_backend = 'auto',
|
||||
bf16_full_eval = False,
|
||||
fp16_full_eval = False,
|
||||
tf32 = None,
|
||||
local_rank = -1,
|
||||
ddp_backend = None,
|
||||
tpu_num_cores = None,
|
||||
tpu_metrics_debug = False,
|
||||
debug = '',
|
||||
dataloader_drop_last = False,
|
||||
eval_steps = None,
|
||||
dataloader_num_workers = 0,
|
||||
dataloader_prefetch_factor = None,
|
||||
past_index = -1,
|
||||
run_name = None,
|
||||
disable_tqdm = None,
|
||||
remove_unused_columns = False,
|
||||
label_names = None,
|
||||
load_best_model_at_end = False,
|
||||
metric_for_best_model = None,
|
||||
greater_is_better = None,
|
||||
ignore_data_skip = False,
|
||||
fsdp = '',
|
||||
fsdp_min_num_params = 0,
|
||||
fsdp_config = None,
|
||||
fsdp_transformer_layer_cls_to_wrap = None,
|
||||
accelerator_config = None,
|
||||
deepspeed = None,
|
||||
label_smoothing_factor = 0.0,
|
||||
optim = 'adamw_8bit',
|
||||
optim_args = None,
|
||||
adafactor = False,
|
||||
group_by_length = False,
|
||||
length_column_name = 'length',
|
||||
report_to = None,
|
||||
ddp_find_unused_parameters = None,
|
||||
ddp_bucket_cap_mb = None,
|
||||
ddp_broadcast_buffers = None,
|
||||
dataloader_pin_memory = True,
|
||||
dataloader_persistent_workers = False,
|
||||
skip_memory_metrics = True,
|
||||
use_legacy_prediction_loop = False,
|
||||
push_to_hub = False,
|
||||
resume_from_checkpoint = None,
|
||||
hub_model_id = None,
|
||||
hub_strategy = 'every_save',
|
||||
hub_token = None,
|
||||
hub_private_repo = None,
|
||||
hub_always_push = False,
|
||||
gradient_checkpointing = False,
|
||||
gradient_checkpointing_kwargs = None,
|
||||
include_inputs_for_metrics = False,
|
||||
eval_do_concat_batches = True,
|
||||
fp16_backend = 'auto',
|
||||
push_to_hub_model_id = None,
|
||||
push_to_hub_organization = None,
|
||||
push_to_hub_token = None,
|
||||
mp_parameters = '',
|
||||
auto_find_batch_size = False,
|
||||
full_determinism = False,
|
||||
torchdynamo = None,
|
||||
ray_scope = 'last',
|
||||
ddp_timeout = 1800,
|
||||
torch_compile = False,
|
||||
torch_compile_backend = None,
|
||||
torch_compile_mode = None,
|
||||
include_tokens_per_second = False,
|
||||
include_num_input_tokens_seen = False,
|
||||
neftune_noise_alpha = None,
|
||||
optim_target_modules = None,
|
||||
batch_eval_metrics = False,
|
||||
eval_on_start = False,
|
||||
use_liger_kernel = False,
|
||||
eval_use_gather_object = False,
|
||||
average_tokens_across_devices = False,
|
||||
max_length = 1024,
|
||||
disable_dropout = True,
|
||||
dataset_num_proc = None,
|
||||
center_rewards_coefficient = None,
|
||||
vllm_sampling_params = None,
|
||||
unsloth_num_chunks = -1,
|
||||
**kwargs,
|
||||
):
|
||||
if learning_rate < 1e-7: raise FloatingPointError(f'Unsloth: Your learning rate of `{learning_rate}` is too small and less than 1e-7! Consider increasing it, otherwise gradient updates will be close to 0!')
|
||||
if learning_rate > 1: raise OverflowError(f'Unsloth: Your learning rate of `{learning_rate}` is way too large (> 1)! Consider decreasing it to 1e-1, otherwise gradient updates will explode!')
|
||||
if output_dir is None and save_strategy == 'steps' and save_steps == 500:
|
||||
output_dir = 'unsloth_training_checkpoints'
|
||||
save_strategy = 'no'
|
||||
if dataset_num_proc is None:
|
||||
from multiprocessing import cpu_count
|
||||
dataset_num_proc = cpu_count()
|
||||
|
||||
super().__init__(
output_dir = output_dir,
overwrite_output_dir = overwrite_output_dir,
do_train = do_train,
do_eval = do_eval,
do_predict = do_predict,
eval_strategy = eval_strategy,
prediction_loss_only = prediction_loss_only,
per_device_train_batch_size = per_device_train_batch_size,
per_device_eval_batch_size = per_device_eval_batch_size,
per_gpu_train_batch_size = per_gpu_train_batch_size,
per_gpu_eval_batch_size = per_gpu_eval_batch_size,
gradient_accumulation_steps = gradient_accumulation_steps,
eval_accumulation_steps = eval_accumulation_steps,
eval_delay = eval_delay,
torch_empty_cache_steps = torch_empty_cache_steps,
learning_rate = learning_rate,
weight_decay = weight_decay,
adam_beta1 = adam_beta1,
adam_beta2 = adam_beta2,
adam_epsilon = adam_epsilon,
max_grad_norm = max_grad_norm,
num_train_epochs = num_train_epochs,
max_steps = max_steps,
lr_scheduler_type = lr_scheduler_type,
warmup_ratio = warmup_ratio,
warmup_steps = warmup_steps,
log_level = log_level,
log_level_replica = log_level_replica,
log_on_each_node = log_on_each_node,
logging_dir = logging_dir,
logging_strategy = logging_strategy,
logging_first_step = logging_first_step,
logging_steps = logging_steps,
logging_nan_inf_filter = logging_nan_inf_filter,
save_strategy = save_strategy,
save_steps = save_steps,
save_total_limit = save_total_limit,
save_safetensors = save_safetensors,
save_on_each_node = save_on_each_node,
save_only_model = save_only_model,
restore_callback_states_from_checkpoint = restore_callback_states_from_checkpoint,
no_cuda = no_cuda,
use_cpu = use_cpu,
use_mps_device = use_mps_device,
seed = seed,
data_seed = data_seed,
jit_mode_eval = jit_mode_eval,
use_ipex = use_ipex,
bf16 = bf16,
fp16 = fp16,
fp16_opt_level = fp16_opt_level,
half_precision_backend = half_precision_backend,
bf16_full_eval = bf16_full_eval,
fp16_full_eval = fp16_full_eval,
tf32 = tf32,
local_rank = local_rank,
ddp_backend = ddp_backend,
tpu_num_cores = tpu_num_cores,
tpu_metrics_debug = tpu_metrics_debug,
debug = debug,
dataloader_drop_last = dataloader_drop_last,
eval_steps = eval_steps,
dataloader_num_workers = dataloader_num_workers,
dataloader_prefetch_factor = dataloader_prefetch_factor,
past_index = past_index,
run_name = run_name,
disable_tqdm = disable_tqdm,
remove_unused_columns = remove_unused_columns,
label_names = label_names,
load_best_model_at_end = load_best_model_at_end,
metric_for_best_model = metric_for_best_model,
greater_is_better = greater_is_better,
ignore_data_skip = ignore_data_skip,
fsdp = fsdp,
fsdp_min_num_params = fsdp_min_num_params,
fsdp_config = fsdp_config,
fsdp_transformer_layer_cls_to_wrap = fsdp_transformer_layer_cls_to_wrap,
accelerator_config = accelerator_config,
deepspeed = deepspeed,
label_smoothing_factor = label_smoothing_factor,
optim = optim,
optim_args = optim_args,
adafactor = adafactor,
group_by_length = group_by_length,
length_column_name = length_column_name,
report_to = report_to,
ddp_find_unused_parameters = ddp_find_unused_parameters,
ddp_bucket_cap_mb = ddp_bucket_cap_mb,
ddp_broadcast_buffers = ddp_broadcast_buffers,
dataloader_pin_memory = dataloader_pin_memory,
dataloader_persistent_workers = dataloader_persistent_workers,
skip_memory_metrics = skip_memory_metrics,
use_legacy_prediction_loop = use_legacy_prediction_loop,
push_to_hub = push_to_hub,
resume_from_checkpoint = resume_from_checkpoint,
hub_model_id = hub_model_id,
hub_strategy = hub_strategy,
hub_token = hub_token,
hub_private_repo = hub_private_repo,
hub_always_push = hub_always_push,
gradient_checkpointing = gradient_checkpointing,
gradient_checkpointing_kwargs = gradient_checkpointing_kwargs,
include_inputs_for_metrics = include_inputs_for_metrics,
eval_do_concat_batches = eval_do_concat_batches,
fp16_backend = fp16_backend,
push_to_hub_model_id = push_to_hub_model_id,
push_to_hub_organization = push_to_hub_organization,
push_to_hub_token = push_to_hub_token,
mp_parameters = mp_parameters,
auto_find_batch_size = auto_find_batch_size,
full_determinism = full_determinism,
torchdynamo = torchdynamo,
ray_scope = ray_scope,
ddp_timeout = ddp_timeout,
torch_compile = torch_compile,
torch_compile_backend = torch_compile_backend,
torch_compile_mode = torch_compile_mode,
include_tokens_per_second = include_tokens_per_second,
include_num_input_tokens_seen = include_num_input_tokens_seen,
neftune_noise_alpha = neftune_noise_alpha,
optim_target_modules = optim_target_modules,
batch_eval_metrics = batch_eval_metrics,
eval_on_start = eval_on_start,
use_liger_kernel = use_liger_kernel,
eval_use_gather_object = eval_use_gather_object,
average_tokens_across_devices = average_tokens_across_devices,
max_length = max_length,
disable_dropout = disable_dropout,
dataset_num_proc = dataset_num_proc,
center_rewards_coefficient = center_rewards_coefficient,
**kwargs)
self.vllm_sampling_params = vllm_sampling_params
self.unsloth_num_chunks = unsloth_num_chunks
pass
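# Illustrative usage of the config above (a minimal sketch; the argument names are the ones
# defined in this __init__, but the values are placeholders, not recommendations):
#
#     config = UnslothRewardConfig(
#         output_dir = "reward_model_outputs",
#         per_device_train_batch_size = 4,
#         gradient_accumulation_steps = 4,
#         learning_rate = 1e-5,
#         max_length = 1024,
#     )
#
# Any extra keyword argument is forwarded to the parent config through **kwargs.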

class _UnslothRewardTrainer(Trainer):
_tag_names = ["trl", "reward-trainer"]

def __init__(
self,
model: Optional[Union[PreTrainedModel, nn.Module]] = None,
args: Optional[RewardConfig] = None,
data_collator: Optional[DataCollator] = None,
train_dataset: Optional[Dataset] = None,
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
processing_class: Optional[
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
] = None,
model_init: Optional[Callable[[], PreTrainedModel]] = None,
compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
callbacks: Optional[list[TrainerCallback]] = None,
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (
None,
None,
),
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
peft_config: Optional[dict] = None,
):
"""
Initialize RewardTrainer.

Args:
model (`transformers.PreTrainedModel`):
The model to train, preferably an `AutoModelForSequenceClassification`.
args (`RewardConfig`):
The arguments to use for training.
data_collator (`transformers.DataCollator`):
The data collator to use for training. If None is specified, the default data collator (`RewardDataCollatorWithPadding`) will be used,
which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
train_dataset (`datasets.Dataset`):
The dataset to use for training.
eval_dataset (`datasets.Dataset`):
The dataset to use for evaluation.
processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
Processing class used to process the data. If provided, will be used to automatically process the inputs
for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
reuse the fine-tuned model.
model_init (`Callable[[], transformers.PreTrainedModel]`):
The model initializer to use for training. If None is specified, the default model initializer will be used.
compute_metrics (`Callable[[transformers.EvalPrediction], dict]`, *optional*, defaults to `compute_accuracy`):
The metrics to use for evaluation. If no metrics are specified, the default metric (`compute_accuracy`) will be used.
callbacks (`list[transformers.TrainerCallback]`):
The callbacks to use for training.
optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
The optimizer and scheduler to use for training.
preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
The function to use to preprocess the logits before computing the metrics.
peft_config (`dict`, defaults to `None`):
The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
"""
if not is_peft_available() and peft_config is not None:
raise ValueError(
"PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
)
elif is_peft_available() and peft_config is not None:
if not isinstance(model, PeftModel):
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_quantized", False):
_supports_gc_kwargs = "gradient_checkpointing_kwargs" in list(
inspect.signature(prepare_model_for_kbit_training).parameters
)

prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}

if not _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None:
warnings.warn(
"You passed `gradient_checkpointing_kwargs` in the trainer's kwargs, but your peft version does not support it. "
"Please update to the latest version of peft to use `gradient_checkpointing_kwargs`.",
UserWarning,
)
elif _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None:
prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs

model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)

model = model

# Disable dropout in the model
if args.disable_dropout:
disable_dropout_in_model(model)

if compute_metrics is None:
compute_metrics = compute_accuracy

if data_collator is None:
if processing_class is None:
raise ValueError(
"A processing_class must be specified when using the default RewardDataCollatorWithPadding"
)

max_length = args.max_length

data_collator = RewardDataCollatorWithPadding(processing_class)

if args.remove_unused_columns:
try:  # for bc before https://github.com/huggingface/transformers/pull/25435
args.remove_unused_columns = False
except FrozenInstanceError:
args = replace(args, remove_unused_columns=False)
# warn users
warnings.warn(
"When using RewardDataCollatorWithPadding, you should set `remove_unused_columns=False` in your RewardConfig."
" We have set it for you, but you should do it yourself in the future.",
UserWarning,
)

self.use_reward_data_collator = True
else:
self.use_reward_data_collator = False

# The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
# input tensor associated with the key "input_ids". However, in Reward, the sampled data does not include the
# "input_ids" key. Instead, the available keys are "input_ids_chosen" and "input_ids_rejected". As a result,
# the trainer issues the warning: "Could not estimate the number of tokens of the input, floating-point
# operations will not be computed." To suppress this warning, we set the "estimate_tokens" key in the model's
# "warnings_issued" dictionary to True. This acts as a flag to indicate that the warning has already been
# issued.
model.warnings_issued["estimate_tokens"] = True

if "input_ids_chosen" not in train_dataset.column_names:
with PartialState().main_process_first():
fn_kwargs = {"tokenizer": processing_class}
train_dataset = train_dataset.map(maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class})
train_dataset = train_dataset.map(
_tokenize,
batched=True,
fn_kwargs=fn_kwargs,
num_proc=args.dataset_num_proc,
)
# This filter is important because otherwise you get samples that exceed the model's context length and
# get truncated => noisy signal: the chosen/rejected label gets lost. The downside is that the
# user might get surprised if N samples are missing from training.
train_dataset = train_dataset.filter(
lambda x: len(x["input_ids_chosen"]) <= max_length and len(x["input_ids_rejected"]) <= max_length,
num_proc=args.dataset_num_proc,
)
if eval_dataset is not None:
eval_dataset = eval_dataset.map(
maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}
)
eval_dataset = eval_dataset.map(
_tokenize,
fn_kwargs=fn_kwargs,
batched=True,
num_proc=args.dataset_num_proc,
)
# This filter is important because otherwise you get samples that exceed the model's context length and
# get truncated => noisy signal: the chosen/rejected label gets lost. The downside is that the
# user might get surprised if N samples are missing from training.
eval_dataset = eval_dataset.filter(
lambda x: len(x["input_ids_chosen"]) <= max_length
and len(x["input_ids_rejected"]) <= max_length,
num_proc=args.dataset_num_proc,
)

super().__init__(
model=model,
args=args,
data_collator=data_collator,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
processing_class=processing_class,
model_init=model_init,
compute_metrics=compute_metrics,
callbacks=callbacks,
optimizers=optimizers,
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
)

# Add tags for models that have been loaded with the correct transformers version
if hasattr(self.model, "add_model_tags"):
self.model.add_model_tags(self._tag_names)

def compute_loss(
self,
model: Union[PreTrainedModel, nn.Module],
inputs: dict[str, Union[torch.Tensor, Any]],
return_outputs=False,
num_items_in_batch=None,
) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
rewards_chosen = model(
input_ids=inputs["input_ids_chosen"],
attention_mask=inputs["attention_mask_chosen"],
return_dict=True,
)["logits"]
rewards_rejected = model(
input_ids=inputs["input_ids_rejected"],
attention_mask=inputs["attention_mask_rejected"],
return_dict=True,
)["logits"]
# calculate loss, optionally modulate with margin
if "margin" in inputs:
loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected - inputs["margin"]).mean()
else:
loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected).mean()

if self.args.center_rewards_coefficient is not None:
loss += self.args.center_rewards_coefficient * torch.mean((rewards_chosen + rewards_rejected) ** 2)

if return_outputs:
return loss, {
"rewards_chosen": rewards_chosen,
"rewards_rejected": rewards_rejected,
}
return loss
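# Added note: the objective above is the standard pairwise (Bradley-Terry style) reward-modeling
# loss, loss = -log(sigmoid(r_chosen - r_rejected)), optionally shifted by a per-sample "margin"
# column and optionally regularized by center_rewards_coefficient * mean((r_chosen + r_rejected)^2)
# so that rewards stay centered around zero; this restates the math the code already implements.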

def prediction_step(
self,
model: Union[PreTrainedModel, nn.Module],
inputs: dict[str, Union[torch.Tensor, Any]],
prediction_loss_only: bool,
ignore_keys: Optional[list[str]] = None,
) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
inputs = self._prepare_inputs(inputs)
if ignore_keys is None:
if hasattr(self.model, "config"):
ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
else:
ignore_keys = []

with torch.no_grad():
loss, logits_dict = self.compute_loss(model, inputs, return_outputs=True)

if prediction_loss_only:
return (loss, None, None)

loss = loss.detach()
logits = tuple(v for k, v in logits_dict.items() if k not in ignore_keys)
logits = nested_detach(logits)
# Stack accepted against rejected, mean over logits
# and softmax to get preferences between accepted and rejected to sum to 1
logits = torch.stack(logits).mean(dim=2).softmax(dim=0).T
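# Added shape note (an assumption, based on a single-logit sequence-classification head):
# each reward tensor is [batch, 1]; stack gives [2, batch, 1], mean(dim=2) gives [2, batch],
# softmax(dim=0) normalizes chosen vs. rejected, and .T yields a [batch, 2] preference matrix.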

labels = torch.zeros(logits.shape[0])
labels = self._prepare_inputs(labels)

return loss, logits, labels

def evaluate(self, *args, **kwargs):
num_print_samples = kwargs.pop("num_print_samples", 4)
self.visualize_samples(num_print_samples)
return super().evaluate(*args, **kwargs)
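# Usage note (illustrative, not part of the original file): because `evaluate` pops
# `num_print_samples`, a call such as `trainer.evaluate(num_print_samples=8)` prints eight
# chosen/rejected pairs via `visualize_samples` before running the normal evaluation loop.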

def visualize_samples(self, num_print_samples: int):
"""
Visualize the reward model's logits predictions.

Args:
num_print_samples (`int`, defaults to `4`):
The number of samples to print. Set to `-1` to print all samples.
"""
eval_dataloader = self.get_eval_dataloader()
table = defaultdict(list)
for _, inputs in enumerate(eval_dataloader):
_, logits, _ = self.prediction_step(self.model, inputs, prediction_loss_only=False)
chosen_text = decode_and_strip_padding(inputs["input_ids_chosen"], self.processing_class)
rejected_text = decode_and_strip_padding(inputs["input_ids_rejected"], self.processing_class)
table["chosen_text"].extend(gather_object(chosen_text))
table["rejected_text"].extend(gather_object(rejected_text))
table["logits"].extend(
gather_object([[round(inner_item, 4) for inner_item in item] for item in logits.tolist()])
)
if num_print_samples >= 0 and len(table["chosen_text"]) >= num_print_samples:
break
df = pd.DataFrame(table)
if self.accelerator.process_index == 0:
if is_rich_available():
print_rich_table(df[:num_print_samples])
if "wandb" in self.args.report_to:
import wandb

if wandb.run is not None:
wandb.log({"completions": wandb.Table(dataframe=df)})

if "comet_ml" in self.args.report_to:
log_table_to_comet_experiment(
name="completions.csv",
table=df,
)

def create_model_card(
self,
model_name: Optional[str] = None,
dataset_name: Optional[str] = None,
tags: Union[str, list[str], None] = None,
):
"""
Creates a draft of a model card using the information available to the `Trainer`.

Args:
model_name (`str` or `None`, *optional*, defaults to `None`):
Name of the model.
dataset_name (`str` or `None`, *optional*, defaults to `None`):
Name of the dataset used for training.
tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
Tags to be associated with the model card.
"""
if not self.is_world_process_zero():
return

if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
base_model = self.model.config._name_or_path
else:
base_model = None

tags = tags or []
if isinstance(tags, str):
tags = [tags]

if hasattr(self.model.config, "unsloth_version"):
tags.append("unsloth")

model_card = generate_model_card(
base_model=base_model,
model_name=model_name,
hub_model_id=self.hub_model_id,
dataset_name=dataset_name,
tags=tags,
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
comet_url=get_comet_experiment_url(),
trainer_name="Reward",
)

model_card.save(os.path.join(self.args.output_dir, "README.md"))

class UnslothRewardTrainer(_UnslothRewardTrainer):
"""
Unsloth wrapper around the reward trainer above: the TRL training logic is kept unchanged,
while `__init__` adds Unsloth-specific setup (precision/dtype alignment, data-collator
selection, padding-side fixes, and RL statistics patching).
"""
def __init__(
self,
model = None,
args = None,
data_collator = None,
train_dataset = None,
eval_dataset = None,
processing_class = None,
model_init = None,
compute_metrics = None,
callbacks = None,
preprocess_logits_for_metrics = None,
peft_config = None,
**kwargs
):
if args is None: args = UnslothRewardConfig()
use_bf16 = getattr(args, 'bf16', False)
use_fp16 = getattr(args, 'fp16', False)
force_float32 = False
if os.environ.get('UNSLOTH_FORCE_FLOAT32', '0') == '1':
print('Unsloth: Switching to float32 training since model cannot work with float16')
force_float32 = True
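# Added descriptive comment: the block below reads the model's dtype and aligns the HF
# fp16/bf16 flags with it. Float16 models get fp16 mixed precision, bfloat16 models get bf16,
# mismatched settings raise a TypeError, and UNSLOTH_FORCE_FLOAT32 disables mixed precision.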
mixed_precision_dtype = os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32')
dtype = getattr(model.config, 'torch_dtype', None)
if dtype is None: dtype = model.get_input_embeddings().dtype
from unsloth_zoo.utils import _get_dtype
dtype = _get_dtype(dtype)
float16 = dtype == torch.float16
if not force_float32 and (float16 and use_bf16): raise TypeError('Unsloth: Model is in float16 precision but you want to use bfloat16 precision. Set fp16 to `True` and bf16 to `False`')
if not force_float32 and (not float16 and use_fp16): raise TypeError('Unsloth: Model is in bfloat16 precision but you want to use float16 precision. Set fp16 to `False` and bf16 to `True`')
if force_float32:
args.fp16 = False
args.bf16 = False
os.environ['ACCELERATE_MIXED_PRECISION'] = 'no'
elif (not use_bf16 and not use_fp16) and mixed_precision_dtype == 'float32':
args.fp16 = float16
args.bf16 = not float16
os.environ['ACCELERATE_MIXED_PRECISION'] = 'fp16' if float16 else 'bf16'
if getattr(args, 'eval_dataset', None) is not None and getattr(args, 'eval_strategy', 'no') == 'no':
args.eval_strategy = 'steps'
if getattr(args, 'eval_steps', None) is None: args.eval_steps = 0.1
ga_steps = getattr(args, 'gradient_accumulation_steps', None)
if ga_steps is not None and ga_steps > 1:
from transformers import __version__ as transformers_version
if Version(transformers_version) <= Version('4.45.2'):
print('**** Unsloth: Please use our fixed gradient_accumulation_steps by updating transformers, TRL and Unsloth!\n'
'`pip install --upgrade --no-cache-dir --force-reinstall --no-deps unsloth transformers trl unsloth_zoo`')
if getattr(args, 'eval_strategy', 'no') != 'no':
eval_bsz = getattr(args, 'per_device_eval_batch_size', 8)
if eval_bsz == 8 and args.per_device_train_batch_size < eval_bsz: args.per_device_eval_batch_size = args.per_device_train_batch_size
if getattr(args, 'eval_accumulation_steps', None) is None and ga_steps is not None: args.eval_accumulation_steps = ga_steps
fp16_full_eval = getattr(args, 'fp16_full_eval', False)
bf16_full_eval = getattr(args, 'bf16_full_eval', False)
if args.fp16 and bf16_full_eval: args.bf16_full_eval = False; args.fp16_full_eval = True
if args.bf16 and fp16_full_eval: args.bf16_full_eval = True; args.fp16_full_eval = False
if force_float32:
args.bf16_full_eval = False
args.fp16_full_eval = False
elif os.environ.get('UNSLOTH_MIXED_PRECISION', 'float32') == 'bfloat16':
args.bf16_full_eval = True
args.fp16_full_eval = False
elif not bf16_full_eval and not fp16_full_eval:
args.bf16_full_eval = args.bf16
args.fp16_full_eval = args.fp16
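# Added note: the branches above mirror the training precision for full-precision evaluation.
# Forced float32 disables both *_full_eval flags, UNSLOTH_MIXED_PRECISION=bfloat16 selects bf16
# evaluation, and otherwise the eval flags simply follow args.bf16 / args.fp16.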
_output_logits = False
if locals().get('compute_metrics', None) is not None: _output_logits = True
if locals().get('preprocess_logits_for_metrics', None) is not None: _output_logits = True
if _output_logits:
os.environ['UNSLOTH_RETURN_LOGITS'] = '1'
if 'max_seq_length' not in locals() and not hasattr(args, 'max_seq_length'):
pass
else:
model_max_seq_length = getattr(model, 'max_seq_length', None)
args_max_seq_length = getattr(args, 'max_seq_length', None)
if args_max_seq_length is None and model_max_seq_length is not None:
max_seq_length = model.max_seq_length
if hasattr(args, 'max_seq_length'): args.max_seq_length = max_seq_length
if model is not None and hasattr(model, 'for_training'):
model.for_training()
if 'tokenizer' in locals() and hasattr(tokenizer, 'padding_side'): tokenizer.padding_side = 'right'
if 'processing_class' in locals():
if hasattr(processing_class, 'padding_side'): processing_class.padding_side = 'right'
if hasattr(processing_class, 'tokenizer') and hasattr(processing_class.tokenizer, 'padding_side'): processing_class.tokenizer.padding_side = 'right'
__tokenizer = processing_class if 'processing_class' in locals() else tokenizer
from unsloth_zoo.vision_utils import UnslothVisionDataCollator
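# Added descriptive comment: the block below picks a data collator when the user did not pass an
# UnslothVisionDataCollator. A seq2seq collator on a dataset without a "labels" column is swapped
# for language-modeling collation, the reverse swap happens when "labels" is present, and when the
# processing class cannot pad itself the wrapped `.tokenizer` is used to rebuild the collator.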
if not isinstance(data_collator, UnslothVisionDataCollator):
if isinstance(data_collator, DataCollatorForSeq2Seq) and 'labels' not in train_dataset.column_names:
data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer, mlm = False, mlm_probability = 0.0)
elif isinstance(data_collator, TransformersDataCollatorForLanguageModeling) and 'labels' in train_dataset.column_names:
data_collator = DataCollatorForSeq2Seq(__tokenizer)
else:
if hasattr(args, 'remove_unused_columns'): args.remove_unused_columns = False
if hasattr(args, 'dataset_text_field'): args.dataset_text_field = ''
if hasattr(args, 'dataset_kwargs'): args.dataset_kwargs = {'skip_prepare_dataset': True}
if not isinstance(data_collator, UnslothVisionDataCollator):
if not hasattr(__tokenizer, 'pad') and hasattr(__tokenizer, 'tokenizer'):
if isinstance(data_collator, DataCollatorForSeq2Seq):
data_collator = DataCollatorForSeq2Seq(__tokenizer.tokenizer)
else:
data_collator = TransformersDataCollatorForLanguageModeling(__tokenizer.tokenizer, mlm = False, mlm_probability = 0.0)
other_metrics = []

from unsloth_zoo.logging_utils import PatchRLStatistics
PatchRLStatistics('reward_trainer', other_metrics)

super().__init__(
model = model,
args = args,
data_collator = data_collator,
train_dataset = train_dataset,
eval_dataset = eval_dataset,
processing_class = processing_class,
model_init = model_init,
compute_metrics = compute_metrics,
callbacks = callbacks,
preprocess_logits_for_metrics = preprocess_logits_for_metrics,
peft_config = peft_config,
**kwargs)
if hasattr(self, 'neftune_hook_handle'):
self.neftune_hook_handle.remove()
if hasattr(self, 'neftune_hook_handle'): del self.neftune_hook_handle
if getattr(args, 'neftune_noise_alpha', None) is not None:
model.get_input_embeddings().neftune_noise_alpha = self.neftune_noise_alpha
pass

pass
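# Illustrative end-to-end usage (a minimal sketch; `model`, `tokenizer`, and `dataset` are
# hypothetical placeholders, not objects defined in this file):
#
#     config  = UnslothRewardConfig(output_dir = "reward_model", per_device_train_batch_size = 2)
#     trainer = UnslothRewardTrainer(
#         model = model,                      # a sequence-classification (reward) model
#         args = config,
#         train_dataset = dataset,            # rows with chosen/rejected pairs
#         processing_class = tokenizer,
#     )
#     trainer.train()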
1141
unsloth_compiled_cache/UnslothSFTTrainer.py
Normal file
1141
unsloth_compiled_cache/UnslothSFTTrainer.py
Normal file
File diff suppressed because it is too large