trinity.common.models.model module#

Base Model Class

class trinity.common.models.model.InferenceModel(config: InferenceModelConfig)[source]#

Bases: ABC

A model for high-performance rollout inference.

__init__(config: InferenceModelConfig) None[source]#
async generate(prompt: str, **kwargs) Sequence[Experience][source]#

Asynchronously generate responses from a prompt.

async chat(messages: List[dict], **kwargs) Sequence[Experience][source]#

Asynchronously generate experiences from a list of chat history messages.

async logprobs(token_ids: List[int], **kwargs) Tensor[source]#

Asynchronously compute logprobs for a list of tokens.

async convert_messages_to_experience(messages: List[dict], tools: List[dict] | None = None, temperature: float | None = None) Experience[source]#

Asynchronously convert a list of messages into an experience.

async prepare() None[source]#

Prepare the model before inference.

abstractmethod async sync_model(model_version: int) int[source]#

Sync the model with the latest model_version.

abstractmethod get_model_version() int[source]#

Get the checkpoint version.

get_available_address() Tuple[str, int][source]#

Get the address of the actor.

get_api_server_url() str | None[source]#

Get the API server URL if available.

get_api_key() str[source]#

Get the API key.

get_model_config() InferenceModelConfig[source]#

Get the model configuration.

get_model_path() str | None[source]#

Get the model path.

async shutdown() None[source]#

Shutdown the model and release resources.
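The class above leaves sync_model and get_model_version abstract, so any concrete engine must provide both. A minimal sketch of that contract, using local stand-in classes rather than the real trinity types:

```python
# Hypothetical sketch: InferenceModelLike and DummyModel are local stand-ins,
# not the real trinity classes. They illustrate the minimal surface a concrete
# subclass must implement: the abstract sync_model and get_model_version.
import asyncio
from abc import ABC, abstractmethod


class InferenceModelLike(ABC):
    """Local stand-in mirroring the abstract surface of InferenceModel."""

    @abstractmethod
    async def sync_model(self, model_version: int) -> int:
        ...

    @abstractmethod
    def get_model_version(self) -> int:
        ...


class DummyModel(InferenceModelLike):
    """A trivial implementation that only tracks the checkpoint version."""

    def __init__(self) -> None:
        self._version = 0

    async def sync_model(self, model_version: int) -> int:
        # Sync to the requested checkpoint and report the version now in use.
        self._version = model_version
        return self._version

    def get_model_version(self) -> int:
        return self._version


model = DummyModel()
print(asyncio.run(model.sync_model(3)))  # -> 3
print(model.get_model_version())         # -> 3
```

A real implementation would additionally override generate, chat, and logprobs with calls into its inference engine.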

class trinity.common.models.model.BaseInferenceModel(config: InferenceModelConfig)[source]#

Bases: InferenceModel

Base class for inference models containing common logic.

__init__(config: InferenceModelConfig) None[source]#
apply_chat_template(tokenizer_or_processor, messages: List[dict]) str[source]#
async convert_messages_to_experience(messages: List[dict], tools: List[dict] | None = None, temperature: float | None = None) Experience[source]#

Asynchronously convert a list of messages into an experience.

Parameters:
  • messages – List of message dictionaries

  • tools – Optional list of tools

  • temperature – Optional temperature for logprobs calculation
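Conceptually, this conversion renders the message list with a chat template, tokenizes the result, and wraps it in an Experience. A rough sketch under stated assumptions; the template, tokenizer, and Experience stand-ins below are illustrative, not the real trinity implementation:

```python
# Hypothetical sketch of what convert_messages_to_experience does at a high
# level. ExperienceLike and apply_simple_template are local assumptions; the
# real code applies the tokenizer's chat template and produces token ids.
from dataclasses import dataclass
from typing import List


@dataclass
class ExperienceLike:
    tokens: List[str]
    prompt_text: str


def apply_simple_template(messages: List[dict]) -> str:
    # Very rough stand-in for something like tokenizer.apply_chat_template.
    return "\n".join(f"<{m['role']}> {m['content']}" for m in messages)


def convert(messages: List[dict]) -> ExperienceLike:
    text = apply_simple_template(messages)
    # Whitespace tokenization stands in for the real tokenizer.
    return ExperienceLike(tokens=text.split(), prompt_text=text)


exp = convert([
    {"role": "user", "content": "hello"},
    {"role": "assistant", "content": "hi"},
])
print(exp.prompt_text)
```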

class trinity.common.models.model.ModelWrapper(model: InferenceModel, enable_lora: bool = False, enable_history: bool = False)[source]#

Bases: object

A wrapper around the InferenceModel Ray actor.

__init__(model: InferenceModel, enable_lora: bool = False, enable_history: bool = False)[source]#

Initialize the ModelWrapper.

Parameters:
  • model (InferenceModel) – The inference model Ray actor.

  • enable_lora (bool) – Whether to enable LoRA. Defaults to False.

  • enable_history (bool) – Whether to enable history recording. Defaults to False.
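The wrapper delegates calls to the underlying model and, when enable_history is set, records the returned experiences as a side effect. A sketch of that pattern with local stand-ins (StubModel and WrapperLike are assumptions, not the real classes, and strings stand in for Experience objects):

```python
# Hypothetical sketch of the ModelWrapper delegation pattern. All classes
# here are local stand-ins; the real wrapper forwards to a Ray actor.
from typing import Any, List


class StubModel:
    def chat(self, messages: List[dict]) -> List[str]:
        # Pretend each call yields one "experience" (a string here).
        return [f"reply-to:{messages[-1]['content']}"]


class WrapperLike:
    def __init__(self, model: Any, enable_history: bool = False) -> None:
        self.model = model
        self.enable_history = enable_history
        self.history: List[str] = []

    def chat(self, messages: List[dict]) -> List[str]:
        # Delegate to the wrapped model, optionally recording the results.
        experiences = self.model.chat(messages)
        if self.enable_history:
            self.history.extend(experiences)
        return experiences


wrapper = WrapperLike(StubModel(), enable_history=True)
wrapper.chat([{"role": "user", "content": "ping"}])
print(wrapper.history)  # -> ['reply-to:ping']
```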

async prepare() None[source]#

Prepare the model wrapper.

generate(*args, **kwargs)[source]#
async generate_async(*args, **kwargs)[source]#
chat(*args, **kwargs)[source]#
async chat_async(*args, **kwargs)[source]#
logprobs(tokens: List[int], temperature: float | None = None) Tensor[source]#

Calculate the logprobs of the given tokens.

async logprobs_async(tokens: List[int], temperature: float | None = None) Tensor[source]#

Asynchronously calculate the logprobs of the given tokens.

convert_messages_to_experience(messages: List[dict], tools: List[dict] | None = None, temperature: float | None = None) Experience[source]#

Convert a list of messages into an experience.

async convert_messages_to_experience_async(messages: List[dict], tools: List[dict] | None = None, temperature: float | None = None) Experience[source]#

Asynchronously convert a list of messages into an experience.
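Several methods above come in sync/async pairs (logprobs and logprobs_async, convert_messages_to_experience and its _async variant). One common way to realize such a pairing, shown here as a sketch with an assumed stand-in computation rather than the real engine call:

```python
# Hypothetical sketch of a sync/async method pair: the synchronous variant
# simply drives the async implementation to completion. PairedAPI and its
# method bodies are illustrative assumptions.
import asyncio
from typing import List


class PairedAPI:
    async def logprobs_async(self, tokens: List[int]) -> List[float]:
        # Stand-in computation; the real method queries the inference engine.
        return [-0.1 * t for t in tokens]

    def logprobs(self, tokens: List[int]) -> List[float]:
        # Blocking wrapper around the async implementation.
        return asyncio.run(self.logprobs_async(tokens))


api = PairedAPI()
print(api.logprobs([1, 2]))  # -> [-0.1, -0.2]
```

Note that a blocking wrapper like this cannot be called from inside an already-running event loop; callers in async code should use the _async variant directly.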

property api_key: str#

Get the API key.

property model_version: int#

Get the version of the model.

property model_version_async: int#

Get the version of the model.

property model_path: str#

Returns the path to the model files based on the current engine type.

  • For ‘vllm’ engine: returns the model path from the configuration (config.model_path)

  • For ‘tinker’ engine: returns the path to the most recent sampler weights

property model_path_async: str#

Returns the path to the model files based on the current engine type.

  • For ‘vllm’ engine: returns the model path from the configuration (config.model_path)

  • For ‘tinker’ engine: returns the path to the most recent sampler weights

property model_name: str | None#

Get the name of the model.

property model_config: InferenceModelConfig#

Get the model config.

property generate_kwargs: Dict[str, Any]#

Get the generation kwargs for the OpenAI client.

get_lora_request() Any[source]#
async get_lora_request_async() Any[source]#
async get_message_token_len(messages: List[dict]) int[source]#
get_openai_client() openai.OpenAI[source]#

Get the OpenAI client.

Returns:

The OpenAI client, with a model_path attribute added that refers to the model path.

Return type:

openai.OpenAI

get_openai_async_client() openai.AsyncOpenAI[source]#

Get the async OpenAI client.

Returns:

The async OpenAI client, with a model_path attribute added that refers to the model path.

Return type:

openai.AsyncOpenAI

async get_current_load() int[source]#

Get the current load metrics of the model.

async sync_model_weights(model_version: int) None[source]#

Sync the model weights.

extract_experience_from_history(clear_history: bool = True) List[Experience][source]#

Extract experiences from the history.
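The clear_history flag suggests drain semantics: return the recorded experiences and, by default, reset the buffer so each call yields only new history. A sketch of that behavior, assuming a local stand-in class with strings in place of Experience objects:

```python
# Hypothetical sketch of extract_experience_from_history's clear_history
# semantics. HistoryHolder is a local stand-in, not the real ModelWrapper.
from typing import List


class HistoryHolder:
    def __init__(self) -> None:
        self.history: List[str] = ["exp-1", "exp-2"]

    def extract_experience_from_history(self, clear_history: bool = True) -> List[str]:
        # Snapshot the buffer, then optionally drain it.
        extracted = list(self.history)
        if clear_history:
            self.history.clear()
        return extracted


holder = HistoryHolder()
print(holder.extract_experience_from_history())  # -> ['exp-1', 'exp-2']
print(holder.history)                            # -> []
```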

async set_workflow_state(state: Dict) None[source]#

Set the workflow state stored on the model.

async clean_workflow_state() None[source]#

Clean the workflow state stored on the model.

async get_workflow_state() Dict[source]#

Get the workflow state stored on the model.

clone_with_isolated_history() ModelWrapper[source]#

Clone the current ModelWrapper with isolated history.

trinity.common.models.model.convert_api_output_to_experience(output) List[Experience][source]#

Convert non-stream/stream API outputs to a list of experiences.

class trinity.common.models.model.HistoryRecordingStream(stream, history: List[Experience], is_async: bool = False)[source]#

Bases: object

__init__(stream, history: List[Experience], is_async: bool = False) None[source]#
close() None[source]#
async aclose() None[source]#
trinity.common.models.model.extract_logprobs(choice) Tensor[source]#

Extract logprobs from a choice's list of logprob dictionaries.
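OpenAI-style chat completions carry per-token logprobs on each choice, one dictionary per generated token. A sketch of the extraction, assuming that response shape; a plain list of floats stands in for the torch.Tensor the real function returns, to keep the example dependency-free:

```python
# Hypothetical sketch of extracting token logprobs from an OpenAI-style
# `choice` object (represented as a plain dict here). The field layout
# choice["logprobs"]["content"] follows the OpenAI chat completion format.
from typing import List


def extract_logprobs_sketch(choice: dict) -> List[float]:
    # One dict per generated token, each carrying its log-probability.
    return [entry["logprob"] for entry in choice["logprobs"]["content"]]


choice = {
    "logprobs": {
        "content": [
            {"token": "Hello", "logprob": -0.02},
            {"token": "!", "logprob": -0.5},
        ]
    }
}
print(extract_logprobs_sketch(choice))  # -> [-0.02, -0.5]
```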