# Source code for trinity.common.patch.kimi
"""Monkey patching for 'kimi_vl' models."""
def kimi_vl_monkey_patch_decorator(func):
    """
    A decorator that applies temporary monkey patches for 'kimi_vl' models before
    the decorated function runs, and restores the original state afterward.

    The patch is applied only if:
    - The model's config.json exists and specifies "model_type": "kimi_vl"
    - The installed transformers version is >= 4.51.0

    Patches include:
    1. Replacing `transformers.activations.PytorchGELUTanh` with `GELUTanh`
    2. Wrapping `importlib.util.spec_from_file_location` to inject DeepseekV3 classes

    The decorator automatically extracts the model path from the function's
    arguments using `inspect.signature`, regardless of whether they are passed
    as positional or keyword arguments. If no model path can be resolved from
    the arguments, the function is called without any patching (previously this
    raised a NameError on the unbound ``model_path`` local).
    """
    import importlib
    import inspect
    import json
    import os
    from functools import wraps

    import transformers
    from packaging import version

    transformers_version = transformers.__version__
    sig = inspect.signature(func)  # Analyze function signature once at decoration time

    @wraps(func)
    def wrapper(*args, **kwargs):
        # Bind actual arguments to parameter names (handles pos/kw/defaults)
        bound_args = sig.bind(*args, **kwargs)
        bound_args.apply_defaults()

        # Extract the model path by name from whichever parameter this callee
        # exposes. Default to None so a missing parameter skips patching
        # instead of raising NameError below.
        model_path = None
        if "model_path" in bound_args.arguments:  # actor/ref worker
            model_path = bound_args.arguments["model_path"]
        elif "model_config" in bound_args.arguments:  # verl config check
            model_path = bound_args.arguments["model_config"].path
        elif "self" in bound_args.arguments:  # critic worker
            model_path = bound_args.arguments["self"].config.model.path

        # Track patch state for cleanup
        kimi_vl_patch_applied = False
        origin_spec_from_file_location = None
        origin_PytorchGELUTanh = None
        try:
            # Only attempt patching when a model path was resolved above.
            config_path = (
                os.path.join(model_path, "config.json") if model_path else None
            )
            if config_path is not None and os.path.exists(config_path):
                with open(config_path, "r") as f:
                    json_hf_config = json.load(f)

                # Check if model requires special patching
                if json_hf_config.get("model_type") == "kimi_vl" and version.parse(
                    transformers_version
                ) >= version.parse("4.51.0"):
                    # Save original values for restoration
                    origin_PytorchGELUTanh = getattr(
                        transformers.activations, "PytorchGELUTanh", None
                    )
                    origin_spec_from_file_location = importlib.util.spec_from_file_location

                    # Patch 1: Replace PytorchGELUTanh (removed/renamed in
                    # newer transformers) with the GELUTanh equivalent.
                    transformers.activations.PytorchGELUTanh = transformers.activations.GELUTanh

                    # Patch 2: Wrap spec_from_file_location so that remote-code
                    # modules loaded from file get DeepseekV3 classes injected
                    # and gradient checkpointing / SDPA flags enabled.
                    def patched_spec_from_file_location(*args_spec, **kwargs_spec):
                        spec = origin_spec_from_file_location(*args_spec, **kwargs_spec)
                        if spec and hasattr(spec, "loader") and spec.loader:
                            original_exec_module = spec.loader.exec_module

                            def patched_exec_module(module):
                                original_exec_module(module)
                                # Inject DeepseekV3* classes from transformers
                                # into the freshly executed module.
                                for attr_name in dir(module):
                                    if attr_name.startswith("DeepseekV3") and hasattr(
                                        transformers, attr_name
                                    ):
                                        setattr(
                                            module,
                                            attr_name,
                                            getattr(transformers, attr_name),
                                        )
                                    elif attr_name in {
                                        "KimiVLPreTrainedModel",
                                        "KimiVLForConditionalGeneration",
                                    }:
                                        # Enable features the upstream remote
                                        # code does not declare.
                                        setattr(
                                            getattr(module, attr_name),
                                            "supports_gradient_checkpointing",
                                            True,
                                        )
                                        setattr(
                                            getattr(module, attr_name),
                                            "_supports_sdpa",
                                            True,
                                        )

                            spec.loader.exec_module = patched_exec_module
                        return spec

                    importlib.util.spec_from_file_location = patched_spec_from_file_location
                    kimi_vl_patch_applied = True

            # Call the original function
            return func(*args, **kwargs)
        finally:
            # Always restore original state, even if an exception occurred
            if kimi_vl_patch_applied:
                # Restore PytorchGELUTanh
                if origin_PytorchGELUTanh is not None:
                    transformers.activations.PytorchGELUTanh = origin_PytorchGELUTanh
                else:
                    # Remove attribute if it didn't exist originally
                    if hasattr(transformers.activations, "PytorchGELUTanh"):
                        delattr(transformers.activations, "PytorchGELUTanh")
                # Restore spec_from_file_location
                if origin_spec_from_file_location is not None:
                    importlib.util.spec_from_file_location = origin_spec_from_file_location

    return wrapper