Source code for trinity.manager.checkpoint_converter
import os
from typing import Optional

from trinity.utils.log import get_logger
class Converter:
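    """Convert verl-format checkpoints into Hugging Face format.

    Handles both Megatron (``dist_ckpt``) and FSDP actor checkpoints, falling
    back to a user-supplied base model when a checkpoint does not carry its
    own Hugging Face config.
    """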
    def __init__(self, base_model_dir: Optional[str] = None):
        self.logger = get_logger(__name__)
        self.base_model_dir = base_model_dir
        self.base_model = None
        self._init_process_group = False
        self.checkpoint_converter = None
    def init_base_model(self) -> bool:
        if not self.base_model_dir:
            self.logger.error(
                "Base model directory is not specified. "
                "Please specify it with `--base-model-dir /path/to/model`."
            )
            return False
        if self.base_model is not None:
            return True
        try:
            self.base_model, _ = self._get_config_and_empty_model(self.base_model_dir)
        except Exception:
            self.logger.error(
                f"Failed to initialize base model from {self.base_model_dir}", exc_info=True
            )
            return False
        return True
    def init_process_group(self):
        if self._init_process_group:
            return
        import torch
        from verl.utils.device import get_nccl_backend
        from verl.utils.distributed import set_numa_affinity

        if "WORLD_SIZE" not in os.environ:
            # Not launched by a distributed launcher: fall back to a
            # single-process group so torch.distributed can initialize.
            os.environ["RANK"] = "0"
            os.environ["LOCAL_RANK"] = "0"
            os.environ["WORLD_SIZE"] = "1"
            os.environ["MASTER_ADDR"] = "localhost"
            os.environ["MASTER_PORT"] = "12355"
        set_numa_affinity()
        torch.distributed.init_process_group(get_nccl_backend())
        self._init_process_group = True
    def init_checkpoint_converter(self, checkpoint_dir: str) -> bool:
        if self.checkpoint_converter is not None:
            return True
        if not os.path.basename(checkpoint_dir).startswith("global_step_"):
            self.logger.error(f"Invalid checkpoint directory {checkpoint_dir}.")
            return False
        actor_ckpt_dir = os.path.join(checkpoint_dir, "actor")
        huggingface_dir = os.path.join(actor_ckpt_dir, "huggingface")
        if not os.path.exists(os.path.join(huggingface_dir, "config.json")):
            # The checkpoint carries no Hugging Face config; recover it
            # from the base model instead.
            if not self.init_base_model():
                self.logger.error(
                    f"Failed to load base model from {self.base_model_dir}, "
                    "please check if the model exists."
                )
                return False
            self.base_model.config.save_pretrained(huggingface_dir)
        from trinity.common.models.utils import get_megatron_converter

        self.init_process_group()
        self.checkpoint_converter = get_megatron_converter(actor_ckpt_dir)
        return True
    def _get_config_and_empty_model(self, model_dir: str):
        import torch
        import transformers
        from accelerate import init_empty_weights

        model_config = transformers.AutoConfig.from_pretrained(model_dir)
        if "ForTokenClassification" in model_config.architectures[0]:
            from transformers import AutoModelForTokenClassification

            auto_model_cls = AutoModelForTokenClassification
        elif "ForCausalLM" in model_config.architectures[0]:
            from transformers import AutoModelForCausalLM

            auto_model_cls = AutoModelForCausalLM
        elif "ForConditionalGeneration" in model_config.architectures[0]:
            # Handle different transformers versions for Vision2Seq models
            from packaging import version

            if version.parse(transformers.__version__) >= version.parse("4.54.0"):
                # transformers >= 4.54.0 uses AutoModelForImageTextToText
                from transformers import AutoModelForImageTextToText

                auto_model_cls = AutoModelForImageTextToText
            else:
                # transformers < 4.54.0 uses AutoModelForVision2Seq
                from transformers import AutoModelForVision2Seq

                auto_model_cls = AutoModelForVision2Seq
        else:
            raise NotImplementedError(f"Unknown architecture {model_config.architectures}")
        with init_empty_weights():
            model = auto_model_cls.from_config(model_config, dtype=torch.bfloat16)
        model.to_empty(device="cpu")
        return model, auto_model_cls
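
    # Note on the pattern above: `init_empty_weights` builds the model on the
    # meta device, so no memory is allocated for weights; `to_empty(device="cpu")`
    # then replaces the meta tensors with uninitialized CPU tensors, which is
    # enough for `save_pretrained(..., state_dict=...)` to write real weights
    # later without ever loading the originals. A standalone sketch of the same
    # idea (the model path is hypothetical):
    #
    #     from accelerate import init_empty_weights
    #     from transformers import AutoConfig, AutoModelForCausalLM
    #
    #     config = AutoConfig.from_pretrained("/path/to/model")
    #     with init_empty_weights():
    #         model = AutoModelForCausalLM.from_config(config)
    #     model.to_empty(device="cpu")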
    def convert(self, checkpoint_dir: str) -> None:
        if os.path.basename(checkpoint_dir).startswith("global_step_"):
            import torch

            actor_ckpt_dir = os.path.join(checkpoint_dir, "actor")
            huggingface_dir = os.path.join(actor_ckpt_dir, "huggingface")
            model = None
            if os.path.exists(huggingface_dir):
                # Skip conversion if a complete Hugging Face checkpoint already exists.
                has_hf_checkpoint = True
                try:
                    model, auto_model_cls = self._get_config_and_empty_model(huggingface_dir)
                    auto_model_cls.from_pretrained(huggingface_dir)
                except Exception:
                    self.logger.debug(
                        f"Incomplete or invalid Hugging Face checkpoint in {huggingface_dir}, will re-convert.",
                        exc_info=True,
                    )
                    has_hf_checkpoint = False
                if has_hf_checkpoint:
                    return
            if model is None:
                if not self.init_base_model():
                    self.logger.error(
                        f"Failed to load base model from {self.base_model_dir}, please check if the model exists."
                    )
                    return
                model = self.base_model
            self.logger.info(f"Converting {checkpoint_dir} to huggingface format...")
            dist_ckpt_dir = os.path.join(actor_ckpt_dir, "dist_ckpt")
            try:
                if os.path.exists(dist_ckpt_dir):  # megatron
                    if not self.init_checkpoint_converter(checkpoint_dir):
                        return
                    state_dict = self.checkpoint_converter.get_state_dict(actor_ckpt_dir)
                else:  # fsdp
                    from trinity.common.models.utils import (
                        load_fsdp_state_dict_from_verl_checkpoint,
                    )

                    state_dict = load_fsdp_state_dict_from_verl_checkpoint(actor_ckpt_dir)
            except Exception:
                self.logger.error(
                    f"Failed to convert {checkpoint_dir} to huggingface format.",
                    exc_info=True,
                )
                return
            state_dict = {k: v.to(torch.bfloat16) for k, v in state_dict.items()}
            model.save_pretrained(huggingface_dir, state_dict=state_dict)
            self.logger.info(f"Saved huggingface checkpoint to {huggingface_dir}")
        else:
            # Not a global_step_* directory: recursively search subdirectories.
            for sub_dir in os.listdir(checkpoint_dir):
                sub_dir_path = os.path.join(checkpoint_dir, sub_dir)
                if os.path.isdir(sub_dir_path):
                    self.convert(sub_dir_path)
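

# --- Usage sketch (illustrative; not part of the original module) ------------
# A minimal way to drive the converter by hand, assuming hypothetical paths.
# `convert` walks the directory tree and converts every `global_step_*`
# checkpoint it finds; `base_model_dir` is only needed when a checkpoint
# lacks a saved Hugging Face config.
#
#     converter = Converter(base_model_dir="/path/to/base/model")
#     converter.convert("/path/to/experiment/checkpoints")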