# -*- coding: utf-8 -*-
"""Base Model Class"""
import asyncio
import copy
import socket
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple, Union
import httpx
import ray
import torch
from torch import Tensor
from trinity.common.config import InferenceModelConfig
from trinity.common.constants import RunningStatus
from trinity.common.experience import Experience
from trinity.common.models.utils import get_action_mask_method
from trinity.utils.log import get_logger
if TYPE_CHECKING:
import openai
[docs]
class InferenceModel(ABC):
"""A model for high performance for rollout inference."""
[docs]
def __init__(self, config: InferenceModelConfig) -> None:
self.config = config
self.logger = get_logger(__name__)
[docs]
async def generate(self, prompt: str, **kwargs) -> Sequence[Experience]:
"""Generate a responses from a prompt in async."""
raise NotImplementedError
[docs]
async def chat(self, messages: List[dict], **kwargs) -> Sequence[Experience]:
"""Generate experiences from a list of history chat messages in async."""
raise NotImplementedError
[docs]
async def logprobs(self, token_ids: List[int], **kwargs) -> Tensor:
"""Generate logprobs for a list of tokens in async."""
raise NotImplementedError
[docs]
async def convert_messages_to_experience(
self,
messages: List[dict],
tools: Optional[List[dict]] = None,
temperature: Optional[float] = None,
) -> Experience:
"""Convert a list of messages into an experience in async."""
raise NotImplementedError
[docs]
async def prepare(self) -> None:
"""Prepare the model before inference."""
pass
[docs]
@abstractmethod
async def sync_model(self, model_version: int) -> int:
"""Sync the model with the latest model_version."""
[docs]
@abstractmethod
def get_model_version(self) -> int:
"""Get the checkpoint version."""
[docs]
def get_available_address(self) -> Tuple[str, int]:
"""Get the address of the actor."""
address = ray.util.get_node_ip_address()
with socket.socket() as s:
s.bind(("", 0))
port = s.getsockname()[1]
return address, port
[docs]
def get_api_server_url(self) -> Optional[str]:
"""Get the API server URL if available."""
return None
[docs]
def get_api_key(self) -> str:
"""Get the API key."""
return "EMPTY"
[docs]
def get_model_config(self) -> InferenceModelConfig:
"""Get the model configuration."""
return self.config
[docs]
def get_model_path(self) -> Optional[str]:
"""Get the model path"""
return self.config.model_path
[docs]
async def shutdown(self) -> None:
"""Shutdown the model and release resources."""
pass
[docs]
class BaseInferenceModel(InferenceModel):
"""Base class for inference models containing common logic."""
[docs]
def __init__(self, config: InferenceModelConfig) -> None:
super().__init__(config)
self.tokenizer = None
self.chat_template = None
if self.config.chat_template:
self.chat_template = self.config.chat_template
self.action_mask_method = get_action_mask_method(self.chat_template)
self.enable_thinking = config.enable_thinking
[docs]
def apply_chat_template(
self,
tokenizer_or_processor,
messages: List[dict],
) -> str:
assert tokenizer_or_processor is not None, "tokenizer_or_processor must be provided."
if messages[-1]["role"] == "assistant":
prompt = tokenizer_or_processor.apply_chat_template(
messages,
tokenize=False,
continue_final_message=True,
chat_template=self.chat_template,
)
else:
prompt = tokenizer_or_processor.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
chat_template=self.chat_template,
enable_thinking=self.enable_thinking,
)
return prompt
def _handle_prompt_truncation(self, prompt: str, **kwargs) -> Tuple[Sequence, bool]:
"""Handle prompt truncation if needed."""
# Tokenize once without truncation to check if truncation is needed
token_ids = self.tokenizer( # type: ignore
prompt,
truncation=False,
return_tensors="pt",
)[
"input_ids"
][0].tolist()
# Check if truncation is needed and apply it
if (
self.config.enable_prompt_truncation
and self.config.max_prompt_tokens is not None
and len(token_ids) > self.config.max_prompt_tokens
):
self.logger.warning(f"Prompt was truncated to {self.config.max_prompt_tokens} tokens")
token_ids = token_ids[: self.config.max_prompt_tokens + 1] # leave one for response
return [
Experience(
tokens=token_ids,
logprobs=torch.zeros(1, dtype=torch.float32),
prompt_length=len(token_ids) - 1,
prompt_text=self.tokenizer.decode(token_ids[:-1]),
response_text=self.tokenizer.decode(token_ids[-1]),
truncate_status="prompt_truncated",
reward=0.0,
)
for _ in range(kwargs.get("n", 1))
], False
return token_ids, True
[docs]
async def convert_messages_to_experience(
self,
messages: List[dict],
tools: Optional[List[dict]] = None,
temperature: Optional[float] = None,
) -> Experience:
"""Convert a list of messages into an experience in async.
Args:
messages: List of message dictionaries
tools: Optional list of tools
temperature: Optional temperature for logprobs calculation
"""
if self.tokenizer is None:
await self._initialize_tokenizer()
token_ids, action_mask, prompt_length = self.action_mask_method(
tokenizer=self.tokenizer,
messages=messages,
tools=tools,
chat_template=self.chat_template,
enable_thinking=self.enable_thinking,
) # (seq_length, ), (seq_length, )
assert token_ids is not None
truncate_status = None
# Truncate prompt if it exceeds max_prompt_tokens
if (
self.config.enable_prompt_truncation
and self.config.max_prompt_tokens is not None
and prompt_length > self.config.max_prompt_tokens
):
truncate_status = "prompt_truncated"
self.logger.warning(
f"Warning: {prompt_length=} exceeds the length limit {self.config.max_prompt_tokens}, "
f"this experience will be not counted in the loss computation."
)
return Experience(
tokens=token_ids[: self.config.max_prompt_tokens + 1],
logprobs=torch.zeros(1, dtype=torch.float32),
prompt_length=self.config.max_prompt_tokens, # Use truncated length
action_mask=torch.zeros(1, dtype=torch.bool), # ignored in loss computation
messages=messages, # messages are not truncated
truncate_status=truncate_status,
)
# Truncate response if it exceeds max_model_len
max_model_len = self.config.max_model_len
if max_model_len is not None and len(token_ids) > max_model_len - 1:
truncate_status = "response_truncated"
self.logger.warning(
f"Warning: {len(token_ids)=} exceeds the length limit {(max_model_len - 1)=}"
)
token_ids = token_ids[: max_model_len - 1]
action_mask = action_mask[: max_model_len - 1]
temperature = temperature if temperature is not None else self.config.temperature
logprobs = await self.logprobs(
token_ids=token_ids.tolist(), temperature=temperature
) # (seq_length - 1,)
return Experience(
tokens=token_ids,
logprobs=logprobs[prompt_length - 1 :],
prompt_length=prompt_length,
action_mask=action_mask[prompt_length:], # Exclude the prompt tokens
messages=messages,
truncate_status=truncate_status,
)
def _history_recorder(func):
"""Decorator to record history of the model calls."""
async def async_wrapper(self, *args, **kwargs):
result = await func(self, *args, **kwargs)
if self.enable_history:
self._record_history(result)
return result
def sync_wrapper(self, *args, **kwargs):
result = func(self, *args, **kwargs)
if self.enable_history:
self._record_history(result)
return result
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
[docs]
class ModelWrapper:
"""A wrapper for the InferenceModel Ray Actor"""
[docs]
def __init__(
self,
model: InferenceModel,
enable_lora: bool = False,
enable_history: bool = False,
):
"""Initialize the ModelWrapper.
Args:
model (InferenceModel): The inference model Ray actor.
enable_lora (bool): Whether to enable LoRA. Default to False.
enable_history (bool): Whether to enable history recording. Default to False.
"""
self.model = model
self.config: InferenceModelConfig = None # init during prepare
self._model_name: str = None
self.api_address: str = None
self._api_key: str = None
self.openai_client: openai.OpenAI = None
self.openai_async_client: openai.AsyncOpenAI = None
self.logger = get_logger(__name__)
self.enable_lora = enable_lora
self.enable_history = enable_history
self.history = []
self.status = RunningStatus.RUNNING
self.workflow_state: Dict = {}
self.request_count = 0
self.state_lock = asyncio.Lock()
[docs]
async def prepare(self) -> None:
"""Prepare the model wrapper."""
self.config = await self.model.get_model_config.remote()
self._model_name = self.config.name
self._api_key = await self.model.get_api_key.remote()
self._engine_type = self.config.engine_type
self._generate_kwargs = {
"temperature": self.config.temperature,
"top_p": self.config.top_p,
"max_tokens": self.config.max_response_tokens,
}
if self.config.enable_thinking is not None:
self._generate_kwargs["extra_body"] = {
"chat_template_kwargs": {"enable_thinking": self.config.enable_thinking}
}
self.api_address = await self.model.get_api_server_url.remote()
if self.api_address is None:
self.logger.info("API server is not enabled for inference model.")
return
if self._engine_type == "tinker":
return
max_retries = 30
interval = 2 # seconds
for i in range(max_retries):
try:
async with httpx.AsyncClient() as client:
response = await client.get(self.api_address + "/health", timeout=5)
if response.status_code == 200:
return
except Exception as e:
self.logger.info(f"API server not ready (attempt {i + 1}/{max_retries}): {e}")
await asyncio.sleep(interval)
raise RuntimeError(
f"API server at {self.api_address} not ready after {max_retries} attempts."
)
def _record_history(self, exps: Union[Experience, List[Experience]]) -> None:
"""Record experiences to history."""
if isinstance(exps, Experience):
self.history.append(exps)
elif isinstance(exps, list):
self.history.extend(exps)
else:
raise TypeError("Expected Experience or List[Experience], got {}".format(type(exps)))
[docs]
@_history_recorder
def generate(self, prompts: List[str], **kwargs) -> List[Experience]:
"""Generate a list of experiences from a list of prompts."""
lora_request = self.get_lora_request()
results = ray.get(
[self.model.generate.remote(prompt, lora_request, **kwargs) for prompt in prompts]
)
return [exp for exps in results for exp in exps]
[docs]
@_history_recorder
async def generate_async(self, prompts: List[str], **kwargs) -> List[Experience]:
"""Generate a list of experiences from a list of prompts in async."""
lora_request = await self.get_lora_request_async()
results = await asyncio.gather(
*[self.model.generate.remote(prompt, lora_request, **kwargs) for prompt in prompts]
)
return [exp for exps in results for exp in exps]
[docs]
@_history_recorder
def chat(self, messages: List[dict], **kwargs) -> List[Experience]:
"""Generate a list of experiences from a list of messages."""
lora_request = self.get_lora_request()
return ray.get(self.model.chat.remote(messages, lora_request=lora_request, **kwargs))
[docs]
@_history_recorder
async def chat_async(self, messages: List[dict], **kwargs) -> List[Experience]:
"""Generate a list of experiences from a list of messages in async."""
lora_request = await self.get_lora_request_async()
return await self.model.chat.remote(messages, lora_request=lora_request, **kwargs)
[docs]
def logprobs(self, tokens: List[int], temperature: Optional[float] = None) -> Tensor:
"""Calculate the logprobs of the given tokens."""
return ray.get(self.model.logprobs.remote(tokens, temperature=temperature))
[docs]
async def logprobs_async(
self, tokens: List[int], temperature: Optional[float] = None
) -> Tensor:
"""Calculate the logprobs of the given tokens in async."""
return await self.model.logprobs.remote(tokens, temperature=temperature)
[docs]
def convert_messages_to_experience(
self,
messages: List[dict],
tools: Optional[List[dict]] = None,
temperature: Optional[float] = None,
) -> Experience:
"""Convert a list of messages into an experience."""
return ray.get(
self.model.convert_messages_to_experience.remote(
messages, tools=tools, temperature=temperature
)
)
[docs]
async def convert_messages_to_experience_async(
self,
messages: List[dict],
tools: Optional[List[dict]] = None,
temperature: Optional[float] = None,
) -> Experience:
"""Convert a list of messages into an experience in async."""
return await self.model.convert_messages_to_experience.remote(
messages, tools=tools, temperature=temperature
)
@property
def api_key(self) -> str:
"""Get the API key."""
return self._api_key
@property
def model_version(self) -> int:
"""Get the version of the model."""
return ray.get(self.model.get_model_version.remote())
@property
async def model_version_async(self) -> int:
"""Get the version of the model."""
return await self.model.get_model_version.remote()
@property
def model_path(self) -> str:
"""
Returns the path to the model files based on the current engine type.
- For 'vllm' engine: returns the model path from the configuration (`config.model_path`)
- For 'tinker' engine: returns the path to the most recent sampler weights
"""
return ray.get(self.model.get_model_path.remote())
@property
async def model_path_async(self) -> str:
"""
Returns the path to the model files based on the current engine type.
- For 'vllm' engine: returns the model path from the configuration (`config.model_path`)
- For 'tinker' engine: returns the path to the most recent sampler weights
"""
return await self.model.get_model_path.remote()
@property
def model_name(self) -> Optional[str]:
"""Get the name of the model."""
return self._model_name
@property
def model_config(self) -> InferenceModelConfig:
"""Get the model config."""
return self.config
@property
def generate_kwargs(self) -> Dict[str, Any]:
"""Get the generation kwargs for openai client."""
return self._generate_kwargs
[docs]
def get_lora_request(self) -> Any:
if self.enable_lora:
return ray.get(self.model.get_lora_request.remote())
else:
return None
[docs]
async def get_lora_request_async(self) -> Any:
if self.enable_lora:
return await self.model.get_lora_request.remote()
else:
return None
[docs]
async def get_message_token_len(self, messages: List[dict]) -> int:
return await self.model.get_message_token_len.remote(messages)
[docs]
def get_openai_client(self) -> "openai.OpenAI":
"""Get the openai client.
Returns:
openai.OpenAI: The openai client. And `model_path` is added to the client which refers to the model path.
"""
import openai
if self.openai_client is not None:
setattr(self.openai_client, "model_path", self.model_path)
return self.openai_client
if not self.api_address:
raise ValueError(
"API server is not enabled for this model. OpenAI client is unavailable."
)
self.openai_client = openai.OpenAI(
base_url=f"{self.api_address}/v1",
api_key=self._api_key,
)
if self._engine_type == "tinker":
# ! TODO: because tinker's OpenAI API interface is in beta,
# we need to use original API in thinker instead.
def chat_completions(*args, **kwargs):
messages = kwargs.pop("messages")
chat_response = ray.get(
self.model.chat.remote(
messages=messages,
with_chat_completion=True,
return_token_ids=self.enable_history,
**kwargs,
)
)
response = chat_response.pop()
if self.enable_history:
self.history.extend(chat_response)
return response
self.openai_client.chat.completions.create = chat_completions
elif self.enable_history:
# add a decorator to the openai client to record history
ori_create = self.openai_client.chat.completions.create
def record_chat_completions(*args, **kwargs):
logprobs = kwargs.pop("logprobs", True)
extra_body = dict(kwargs.pop("extra_body", {}))
if self.config.enable_thinking is not None:
chat_template_kwargs = dict(extra_body.get("chat_template_kwargs", {}))
chat_template_kwargs["enable_thinking"] = self.config.enable_thinking
extra_body["chat_template_kwargs"] = chat_template_kwargs
extra_body["return_token_ids"] = True
response = ori_create(*args, extra_body=extra_body, logprobs=logprobs, **kwargs)
if kwargs.get("stream", False):
return HistoryRecordingStream(response, self.history, is_async=False)
self.history.extend(convert_api_output_to_experience(response))
return response
self.openai_client.chat.completions.create = record_chat_completions
setattr(self.openai_client, "model_path", self.model_path)
return self.openai_client
[docs]
def get_openai_async_client(self) -> "openai.AsyncOpenAI":
"""Get the async openai client.
Returns:
openai.AsyncOpenAI: The async openai client. And `model_path` is added to the client which refers to the model path.
"""
import openai
if self.openai_async_client is not None:
setattr(self.openai_async_client, "model_path", self.model_path)
return self.openai_async_client
if not self.api_address:
raise ValueError(
"API server is not enabled for this model. OpenAI async client is unavailable."
)
# first make sure that we have the sync openai client
self.openai_async_client = openai.AsyncOpenAI(
base_url=f"{self.api_address}/v1",
api_key=self._api_key,
)
if self._engine_type == "tinker":
# ! TODO: because tinker's OpenAI API interface is in beta,
# we need to use original API in thinker instead.
async def chat_completions(*args, **kwargs):
messages = kwargs.pop("messages")
chat_response = await self.model.chat.remote(
messages=messages,
with_chat_completion=True,
return_token_ids=self.enable_history,
**kwargs,
)
response = chat_response.pop()
if self.enable_history:
self.history.extend(chat_response)
return response
self.openai_async_client.chat.completions.create = chat_completions
elif self.enable_history:
# add a decorator to the openai client to record history
ori_create = self.openai_async_client.chat.completions.create
async def record_chat_completions(*args, **kwargs):
logprobs = kwargs.pop("logprobs", True)
extra_body = dict(kwargs.pop("extra_body", {}))
if self.config.enable_thinking is not None:
chat_template_kwargs = dict(extra_body.get("chat_template_kwargs", {}))
chat_template_kwargs["enable_thinking"] = self.config.enable_thinking
extra_body["chat_template_kwargs"] = chat_template_kwargs
extra_body["return_token_ids"] = True
response = await ori_create(
*args, extra_body=extra_body, logprobs=logprobs, **kwargs
)
if kwargs.get("stream", False):
return HistoryRecordingStream(response, self.history, is_async=True)
self.history.extend(convert_api_output_to_experience(response))
return response
self.openai_async_client.chat.completions.create = record_chat_completions
# get model_path from the sync openai client to avoid async call here
setattr(self.openai_async_client, "model_path", self.model_path)
return self.openai_async_client
[docs]
async def get_current_load(self) -> int:
"""Get the current load metrics of the model."""
if not self.api_address:
raise ValueError(
"API server is not enabled for this model. Load metrics is unavailable."
)
async with httpx.AsyncClient() as client:
response = await client.get(f"{self.api_address}/load")
data = response.json()
return data["server_load"]
[docs]
async def sync_model_weights(self, model_version: int) -> None:
"""Sync the model weights"""
await self.model.sync_model.remote(model_version)
# Workflow state management methods
[docs]
async def set_workflow_state(self, state: Dict) -> None:
"""Set the state of workflow using the model."""
async with self.state_lock:
self.workflow_state.update(state)
[docs]
async def clean_workflow_state(self) -> None:
"""Clean the state of workflow using the model."""
async with self.state_lock:
self.workflow_state = {}
self.history.clear()
[docs]
async def get_workflow_state(self) -> Dict:
"""Get the state of workflow using the model."""
async with self.state_lock:
return self.workflow_state.copy()
[docs]
def clone_with_isolated_history(self) -> "ModelWrapper":
"""Clone the current ModelWrapper with isolated history."""
new_wrapper = copy.copy(self)
new_wrapper.openai_async_client = None
new_wrapper.openai_client = None
new_wrapper.history = []
return new_wrapper
[docs]
def convert_api_output_to_experience(
output,
) -> List[Experience]:
"""Convert non-stream/stream API outputs to a list of experiences."""
if hasattr(output, "choices"):
return _convert_completion_output_to_experience(output)
return _convert_stream_chunks_to_experience(output)
[docs]
class HistoryRecordingStream:
[docs]
def __init__(self, stream, history: List[Experience], is_async: bool = False) -> None:
self._stream = stream
self._history = history
self._chunks = []
self._recorded = False
self._is_async = is_async
if is_async:
self._iterator = stream.__aiter__()
else:
self._iterator = iter(stream)
# --- Sync methods ---
def __iter__(self):
if self._is_async:
raise TypeError("Use 'async for' for async streams.")
return self
def __next__(self):
if self._is_async:
raise TypeError("Use 'async for' for async streams.")
try:
chunk = next(self._iterator)
except StopIteration:
self._record_history_once()
raise
self._chunks.append(chunk)
return chunk
[docs]
def close(self) -> None:
if self._is_async:
raise TypeError("Use 'aclose' for async streams.")
self._record_history_once()
close_fn = getattr(self._stream, "close", None)
if callable(close_fn):
close_fn()
# --- Async methods ---
def __aiter__(self):
if not self._is_async:
raise TypeError("Use 'for' for sync streams.")
return self
async def __anext__(self):
if not self._is_async:
raise TypeError("Use 'for' for sync streams.")
try:
chunk = await self._iterator.__anext__()
except StopAsyncIteration:
self._record_history_once()
raise
self._chunks.append(chunk)
return chunk
[docs]
async def aclose(self) -> None:
if not self._is_async:
raise TypeError("Use 'close' for sync streams.")
self._record_history_once()
close_fn = getattr(self._stream, "aclose", None)
if callable(close_fn):
close_result = close_fn()
if hasattr(close_result, "__await__"):
await close_result
return
close_fn = getattr(self._stream, "close", None)
if callable(close_fn):
close_fn()
def _record_history_once(self) -> None:
if self._recorded:
return
self._recorded = True
if self._chunks:
self._history.extend(convert_api_output_to_experience(self._chunks))
def __getattr__(self, name: str):
return getattr(self._stream, name)
def _convert_completion_output_to_experience(output) -> List[Experience]:
"""Convert non-stream chat completion output to experiences."""
return [
Experience(
tokens=torch.cat(
(
torch.tensor(output.prompt_token_ids, dtype=torch.int32),
torch.tensor(choice.token_ids, dtype=torch.int32),
)
),
logprobs=extract_logprobs(choice),
prompt_length=len(output.prompt_token_ids),
response_text=getattr(choice.message, "content", None),
)
for choice in output.choices
]
def _convert_stream_chunks_to_experience(chunks: Sequence[Any]) -> List[Experience]:
"""Convert streamed chat completion chunks to experiences."""
prompt_token_ids: Optional[List[int]] = None
by_choice: Dict[int, Dict[str, Any]] = {}
for chunk in chunks:
if prompt_token_ids is None and hasattr(chunk, "prompt_token_ids"):
chunk_prompt_token_ids = getattr(chunk, "prompt_token_ids", None)
if chunk_prompt_token_ids is not None:
prompt_token_ids = list(chunk_prompt_token_ids)
for choice in getattr(chunk, "choices", []) or []:
idx = getattr(choice, "index", 0)
if idx not in by_choice:
by_choice[idx] = {
"token_ids": [],
"logprobs": [],
"response_text_parts": [],
}
data = by_choice[idx]
token_ids = getattr(choice, "token_ids", None)
if token_ids is not None:
data["token_ids"].extend(token_ids)
choice_logprobs = getattr(choice, "logprobs", None)
if (
choice_logprobs is not None
and getattr(choice_logprobs, "content", None) is not None
):
for token_logprob in choice_logprobs.content:
data["logprobs"].append(token_logprob.logprob)
if token_ids is None:
token_id = getattr(token_logprob, "token_id", None)
if token_id is not None:
data["token_ids"].append(token_id)
delta = getattr(choice, "delta", None)
if delta is not None:
delta_content = getattr(delta, "content", None)
if isinstance(delta_content, str) and len(delta_content) > 0:
data["response_text_parts"].append(delta_content)
prompt_token_ids = prompt_token_ids or []
exps: List[Experience] = []
for idx in sorted(by_choice.keys()):
data = by_choice[idx]
response_token_ids = data["token_ids"]
if len(response_token_ids) == 0:
continue
response_text = "".join(data["response_text_parts"])
exps.append(
Experience(
tokens=torch.tensor(prompt_token_ids + response_token_ids, dtype=torch.int32),
logprobs=torch.tensor(data["logprobs"], dtype=torch.float32),
prompt_length=len(prompt_token_ids),
response_text=response_text,
)
)
return exps