trinity.common.models.model module#
Base Model Class
- class trinity.common.models.model.InferenceModel(config: InferenceModelConfig)[source]#
Bases: ABC
A model for high-performance rollout inference.
- __init__(config: InferenceModelConfig) None[source]#
- async generate(prompt: str, **kwargs) Sequence[Experience][source]#
Generate responses from a prompt asynchronously.
- async chat(messages: List[dict], **kwargs) Sequence[Experience][source]#
Generate experiences from a list of chat history messages asynchronously.
- async logprobs(token_ids: List[int], **kwargs) Tensor[source]#
Generate logprobs for a list of tokens asynchronously.
- async convert_messages_to_experience(messages: List[dict], tools: List[dict] | None = None, temperature: float | None = None) Experience[source]#
Convert a list of messages into an experience asynchronously.
- abstractmethod async sync_model(model_version: int) int[source]#
Sync the model with the latest model_version.
- get_model_config() InferenceModelConfig[source]#
Get the model configuration.
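The async interface above can be illustrated with a minimal toy object that mimics the method signatures (a sketch only: EchoModel is hypothetical, and a simple dataclass stands in for trinity's Experience type):

```python
import asyncio
from dataclasses import dataclass
from typing import List, Sequence


@dataclass
class Experience:  # stand-in for trinity.common.experience.Experience
    response_text: str


class EchoModel:
    """Toy object following the InferenceModel method shapes (not a real subclass)."""

    def __init__(self) -> None:
        self.model_version = -1

    async def generate(self, prompt: str, **kwargs) -> Sequence[Experience]:
        # A real model would run rollout inference; here we just echo the prompt.
        return [Experience(response_text=f"echo: {prompt}")]

    async def chat(self, messages: List[dict], **kwargs) -> Sequence[Experience]:
        # Respond to the last message in the chat history.
        return [Experience(response_text=f"echo: {messages[-1]['content']}")]

    async def sync_model(self, model_version: int) -> int:
        # A real implementation would pull the latest weights; we just record the version.
        self.model_version = model_version
        return self.model_version


async def main():
    model = EchoModel()
    exps = await model.generate("hello")
    version = await model.sync_model(3)
    return exps[0].response_text, version


print(asyncio.run(main()))  # → ('echo: hello', 3)
```

All generation entry points are coroutines, so callers are expected to await them (or drive them with asyncio.run) rather than call them directly.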
- class trinity.common.models.model.BaseInferenceModel(config: InferenceModelConfig)[source]#
Bases: InferenceModel
Base class for inference models containing common logic.
- __init__(config: InferenceModelConfig) None[source]#
- async convert_messages_to_experience(messages: List[dict], tools: List[dict] | None = None, temperature: float | None = None) Experience[source]#
Convert a list of messages into an experience asynchronously.
- Parameters:
messages – List of message dictionaries
tools – Optional list of tools
temperature – Optional temperature for logprobs calculation
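The messages parameter follows the familiar OpenAI-style chat format, a list of dicts with role and content keys (shown here as an assumed example; the exact accepted keys depend on the model's chat template):

```python
# An OpenAI-style chat history suitable for convert_messages_to_experience.
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is 2 + 2?"},
    {"role": "assistant", "content": "4"},
]

# Each entry pairs a speaker role with its message text.
roles = [m["role"] for m in messages]
print(roles)  # → ['system', 'user', 'assistant']
```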
- class trinity.common.models.model.ModelWrapper(model: InferenceModel, enable_lora: bool = False, enable_history: bool = False)[source]#
Bases: object
A wrapper around an InferenceModel Ray actor.
- __init__(model: InferenceModel, enable_lora: bool = False, enable_history: bool = False)[source]#
Initialize the ModelWrapper.
- Parameters:
model (InferenceModel) – The inference model Ray actor.
enable_lora (bool) – Whether to enable LoRA. Defaults to False.
enable_history (bool) – Whether to enable history recording. Defaults to False.
- logprobs(tokens: List[int], temperature: float | None = None) Tensor[source]#
Calculate the logprobs of the given tokens.
- async logprobs_async(tokens: List[int], temperature: float | None = None) Tensor[source]#
Calculate the logprobs of the given tokens asynchronously.
- convert_messages_to_experience(messages: List[dict], tools: List[dict] | None = None, temperature: float | None = None) Experience[source]#
Convert a list of messages into an experience.
- async convert_messages_to_experience_async(messages: List[dict], tools: List[dict] | None = None, temperature: float | None = None) Experience[source]#
Convert a list of messages into an experience asynchronously.
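The paired sync/async methods (logprobs/logprobs_async, convert_messages_to_experience/…_async) suggest a synchronous facade over async calls. A minimal sketch of that pattern, with hypothetical names and a dummy model call standing in for the real actor round-trip:

```python
import asyncio
from typing import List, Optional


class Wrapper:
    """Toy illustration of a sync method delegating to its async counterpart."""

    async def logprobs_async(
        self, tokens: List[int], temperature: Optional[float] = None
    ) -> List[float]:
        # Stand-in for the real model call: pretend every token has logprob -1.0.
        return [-1.0 for _ in tokens]

    def logprobs(
        self, tokens: List[int], temperature: Optional[float] = None
    ) -> List[float]:
        # Synchronous facade: drive the coroutine to completion.
        return asyncio.run(self.logprobs_async(tokens, temperature))


w = Wrapper()
print(w.logprobs([1, 2, 3]))  # → [-1.0, -1.0, -1.0]
```

Note that a facade like this can only be called from code that is not already inside a running event loop; async callers should use the _async variants directly.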
- property api_key: str#
Get the API key.
- property model_version: int#
Get the version of the model.
- property model_version_async: int#
Get the version of the model.
- property model_path: str#
Returns the path to the model files based on the current engine type.
For ‘vllm’ engine: returns the model path from the configuration (config.model_path)
For ‘tinker’ engine: returns the path to the most recent sampler weights
- property model_path_async: str#
Returns the path to the model files based on the current engine type.
For ‘vllm’ engine: returns the model path from the configuration (config.model_path)
For ‘tinker’ engine: returns the path to the most recent sampler weights
- property model_name: str | None#
Get the name of the model.
- property model_config: InferenceModelConfig#
Get the model config.
- property generate_kwargs: Dict[str, Any]#
Get the generation kwargs for the OpenAI client.
- get_openai_client() openai.OpenAI[source]#
Get the OpenAI client.
- Returns:
The OpenAI client. A model_path attribute referring to the model path is attached to the client.
- Return type:
openai.OpenAI
- get_openai_async_client() openai.AsyncOpenAI[source]#
Get the async OpenAI client.
- Returns:
The async OpenAI client. A model_path attribute referring to the model path is attached to the client.
- Return type:
openai.AsyncOpenAI
- extract_experience_from_history(clear_history: bool = True) List[Experience][source]#
Extract experiences from the history.
- clone_with_isolated_history() ModelWrapper[source]#
Clone the current ModelWrapper with isolated history.
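The history mechanism can be sketched with a small stand-alone class (hypothetical, not trinity's implementation): when enable_history is on, generated experiences accumulate in a list, and extract_experience_from_history drains that list, clearing it by default.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Experience:  # stand-in for trinity's Experience type
    response_text: str


class HistoryHolder:
    """Toy sketch of extract_experience_from_history semantics."""

    def __init__(self) -> None:
        self.history: List[Experience] = []

    def record(self, exp: Experience) -> None:
        # Called whenever a new experience is produced.
        self.history.append(exp)

    def extract_experience_from_history(
        self, clear_history: bool = True
    ) -> List[Experience]:
        # Return a snapshot of accumulated experiences, optionally resetting.
        out = list(self.history)
        if clear_history:
            self.history.clear()
        return out


h = HistoryHolder()
h.record(Experience("first"))
got = h.extract_experience_from_history()
print(len(got), len(h.history))  # → 1 0
```

Under this reading, clone_with_isolated_history would share the underlying model actor but give the clone its own empty history list, so concurrent users do not see each other's experiences.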
- trinity.common.models.model.convert_api_output_to_experience(output) List[Experience][source]#
Convert streaming or non-streaming API outputs to a list of experiences.
- class trinity.common.models.model.HistoryRecordingStream(stream, history: List[Experience], is_async: bool = False)[source]#
Bases: object
- __init__(stream, history: List[Experience], is_async: bool = False) None[source]#