trinity.common.models.model module

目录

trinity.common.models.model module#

Base Model Class

class trinity.common.models.model.InferenceModel(config: InferenceModelConfig)[源代码]#

基类:ABC

A high-performance model for rollout inference.

__init__(config: InferenceModelConfig) None[源代码]#
async generate(prompt: str, **kwargs) Sequence[Experience][源代码]#

Generate responses from a prompt in async.

async chat(messages: List[dict], **kwargs) Sequence[Experience][源代码]#

Generate experiences from a list of history chat messages in async.

async logprobs(token_ids: List[int], **kwargs) Tensor[源代码]#

Generate logprobs for a list of tokens in async.

async convert_messages_to_experience(messages: List[dict], tools: List[dict] | None = None, temperature: float | None = None) Experience[源代码]#

Convert a list of messages into an experience in async.

async prepare() None[源代码]#

Prepare the model before inference.

abstractmethod async sync_model(model_version: int) int[源代码]#

Sync the model with the latest model_version.

abstractmethod get_model_version() int[源代码]#

Get the checkpoint version.

get_available_address() Tuple[str, int][源代码]#

Get the address of the actor.

get_api_server_url() str | None[源代码]#

Get the API server URL if available.

get_api_key() str[源代码]#

Get the API key.

get_model_config() InferenceModelConfig[源代码]#

Get the model configuration.

get_model_path() str | None[源代码]#

Get the model path.

async shutdown() None[源代码]#

Shutdown the model and release resources.

class trinity.common.models.model.BaseInferenceModel(config: InferenceModelConfig)[源代码]#

基类:InferenceModel

Base class for inference models containing common logic.

__init__(config: InferenceModelConfig) None[源代码]#
apply_chat_template(tokenizer_or_processor, messages: List[dict]) str[源代码]#
async convert_messages_to_experience(messages: List[dict], tools: List[dict] | None = None, temperature: float | None = None) Experience[源代码]#

Convert a list of messages into an experience in async.

参数:
  • messages -- List of message dictionaries

  • tools -- Optional list of tools

  • temperature -- Optional temperature for logprobs calculation

class trinity.common.models.model.ModelWrapper(model: InferenceModel, enable_lora: bool = False, enable_history: bool = False)[源代码]#

基类:object

A wrapper for the InferenceModel Ray Actor

__init__(model: InferenceModel, enable_lora: bool = False, enable_history: bool = False)[源代码]#

Initialize the ModelWrapper.

参数:
  • model (InferenceModel) -- The inference model Ray actor.

  • enable_lora (bool) -- Whether to enable LoRA. Default to False.

  • enable_history (bool) -- Whether to enable history recording. Default to False.

async prepare() None[源代码]#

Prepare the model wrapper.

generate(*args, **kwargs)[源代码]#
async generate_async(*args, **kwargs)[源代码]#
chat(*args, **kwargs)[源代码]#
async chat_async(*args, **kwargs)[源代码]#
logprobs(tokens: List[int], temperature: float | None = None) Tensor[源代码]#

Calculate the logprobs of the given tokens.

async logprobs_async(tokens: List[int], temperature: float | None = None) Tensor[源代码]#

Calculate the logprobs of the given tokens in async.

convert_messages_to_experience(messages: List[dict], tools: List[dict] | None = None, temperature: float | None = None) Experience[源代码]#

Convert a list of messages into an experience.

async convert_messages_to_experience_async(messages: List[dict], tools: List[dict] | None = None, temperature: float | None = None) Experience[源代码]#

Convert a list of messages into an experience in async.

property api_key: str#

Get the API key.

property model_version: int#

Get the version of the model.

property model_version_async: int#

Get the version of the model.

property model_path: str#

Returns the path to the model files based on the current engine type.

  • For 'vllm' engine: returns the model path from the configuration (config.model_path)

  • For 'tinker' engine: returns the path to the most recent sampler weights

property model_path_async: str#

Returns the path to the model files based on the current engine type.

  • For 'vllm' engine: returns the model path from the configuration (config.model_path)

  • For 'tinker' engine: returns the path to the most recent sampler weights

property model_name: str | None#

Get the name of the model.

property model_config: InferenceModelConfig#

Get the model config.

property generate_kwargs: Dict[str, Any]#

Get the generation kwargs for openai client.

get_lora_request() Any[源代码]#
async get_lora_request_async() Any[源代码]#
async get_message_token_len(messages: List[dict]) int[源代码]#
get_openai_client() openai.OpenAI[源代码]#

Get the openai client.

返回:

The openai client. A model_path attribute referring to the model path is added to the client.

返回类型:

openai.OpenAI

get_openai_async_client() openai.AsyncOpenAI[源代码]#

Get the async openai client.

返回:

The async openai client. A model_path attribute referring to the model path is added to the client.

返回类型:

openai.AsyncOpenAI

async get_current_load() int[源代码]#

Get the current load metrics of the model.

async sync_model_weights(model_version: int) None[源代码]#

Sync the model weights.

extract_experience_from_history(clear_history: bool = True) List[Experience][源代码]#

Extract experiences from the history.

async set_workflow_state(state: Dict) None[源代码]#

Set the state of workflow using the model.

async clean_workflow_state() None[源代码]#

Clean the state of workflow using the model.

async get_workflow_state() Dict[源代码]#

Get the state of workflow using the model.

clone_with_isolated_history() ModelWrapper[源代码]#

Clone the current ModelWrapper with isolated history.

trinity.common.models.model.convert_api_output_to_experience(output) List[Experience][源代码]#

Convert non-stream/stream API outputs to a list of experiences.

class trinity.common.models.model.HistoryRecordingStream(stream, history: List[Experience], is_async: bool = False)[源代码]#

基类:object

__init__(stream, history: List[Experience], is_async: bool = False) None[源代码]#
close() None[源代码]#
async aclose() None[源代码]#
trinity.common.models.model.extract_logprobs(choice) Tensor[源代码]#

Extract logprobs from a list of logprob dictionaries.