trinity.common.models.model module#
Base Model Class
- class trinity.common.models.model.InferenceModel(config: InferenceModelConfig)[source]#
Bases: ABC
A model for high-performance rollout inference.
- __init__(config: InferenceModelConfig) → None[source]#
- async generate(prompt: str, **kwargs) → Sequence[Experience][source]#
Generate responses from a prompt asynchronously.
- async chat(messages: List[dict], **kwargs) → Sequence[Experience][source]#
Generate experiences from a list of chat history messages asynchronously.
- async logprobs(token_ids: List[int], **kwargs) → Tensor[source]#
Generate logprobs for a list of tokens asynchronously.
- async convert_messages_to_experience(messages: List[dict], tools: List[dict] | None = None, temperature: float | None = None) → Experience[source]#
Convert a list of messages into an experience asynchronously.
- abstractmethod async sync_model(model_version: int) → int[source]#
Sync the model to the latest model_version.
- get_model_config() → InferenceModelConfig[source]#
Get the model configuration.
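The abstract surface above can be exercised with a toy stand-in. The sketch below deliberately does not import trinity; `ToyInferenceModel` and `EchoModel` are hypothetical stand-ins that only mirror the documented `generate`/`sync_model` signatures, returning plain dicts where the real class returns Experience objects.

```python
import asyncio
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Sequence


class ToyInferenceModel(ABC):
    """Hypothetical stand-in mirroring part of the InferenceModel surface."""

    @abstractmethod
    async def generate(self, prompt: str, **kwargs) -> Sequence[Dict[str, Any]]:
        ...

    @abstractmethod
    async def sync_model(self, model_version: int) -> int:
        ...


class EchoModel(ToyInferenceModel):
    """Toy implementation that echoes the prompt back."""

    def __init__(self) -> None:
        self.model_version = 0

    async def generate(self, prompt: str, **kwargs) -> Sequence[Dict[str, Any]]:
        # A real engine would run rollout inference here; `n` is a toy kwarg.
        n = kwargs.get("n", 1)
        return [{"prompt": prompt, "response": f"echo: {prompt}"} for _ in range(n)]

    async def sync_model(self, model_version: int) -> int:
        # A real engine would load the weights for the given version.
        self.model_version = model_version
        return self.model_version


async def main() -> None:
    model = EchoModel()
    experiences = await model.generate("hello", n=2)
    print(len(experiences), await model.sync_model(3))


asyncio.run(main())
```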
- class trinity.common.models.model.BaseInferenceModel(config: InferenceModelConfig)[source]#
Bases: InferenceModel
Base class for inference models containing common logic.
- __init__(config: InferenceModelConfig) → None[source]#
- async convert_messages_to_experience(messages: List[dict], tools: List[dict] | None = None, temperature: float | None = None) → Experience[source]#
Convert a list of messages into an experience asynchronously.
- Parameters:
messages -- List of message dictionaries.
tools -- Optional list of tool definitions.
temperature -- Optional temperature for logprobs calculation.
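The exact message schema accepted here depends on the tokenizer's chat template; the snippet below shows the common role/content chat shape and a JSON-schema style tool definition that such methods typically expect. The tool name `add` is purely illustrative.

```python
# Typical inputs for convert_messages_to_experience; the exact schema a given
# chat template accepts may differ.
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is 2 + 2?"},
    {"role": "assistant", "content": "4"},
]

# Optional tool definitions, in the common JSON-schema style.
tools = [
    {
        "type": "function",
        "function": {
            "name": "add",  # hypothetical tool name
            "description": "Add two integers.",
            "parameters": {
                "type": "object",
                "properties": {
                    "a": {"type": "integer"},
                    "b": {"type": "integer"},
                },
                "required": ["a", "b"],
            },
        },
    }
]

# Sketch of the call (inside an async context):
# experience = await model.convert_messages_to_experience(
#     messages, tools=tools, temperature=1.0
# )
```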
- class trinity.common.models.model.ModelWrapper(model: InferenceModel, enable_lora: bool = False, enable_history: bool = False)[source]#
Bases: object
A wrapper for the InferenceModel Ray actor.
- __init__(model: InferenceModel, enable_lora: bool = False, enable_history: bool = False)[source]#
Initialize the ModelWrapper.
- Parameters:
model (InferenceModel) -- The inference model Ray actor.
enable_lora (bool) -- Whether to enable LoRA. Defaults to False.
enable_history (bool) -- Whether to enable history recording. Defaults to False.
- logprobs(tokens: List[int], temperature: float | None = None) → Tensor[source]#
Calculate the logprobs of the given tokens.
- async logprobs_async(tokens: List[int], temperature: float | None = None) → Tensor[source]#
Calculate the logprobs of the given tokens asynchronously.
- convert_messages_to_experience(messages: List[dict], tools: List[dict] | None = None, temperature: float | None = None) → Experience[source]#
Convert a list of messages into an experience.
- async convert_messages_to_experience_async(messages: List[dict], tools: List[dict] | None = None, temperature: float | None = None) → Experience[source]#
Convert a list of messages into an experience asynchronously.
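ModelWrapper exposes sync/async method pairs such as logprobs / logprobs_async. A minimal sketch of that pairing, assuming the sync variant simply drives the async one to completion; `ToyWrapper` is hypothetical, and a real wrapper would forward the call to the inference-model Ray actor and return a Tensor rather than a list of floats.

```python
import asyncio
from typing import List


class ToyWrapper:
    """Hypothetical illustration of the sync/async method pairing."""

    async def logprobs_async(self, tokens: List[int]) -> List[float]:
        # A real wrapper would query the inference-model actor here;
        # we fabricate per-token logprobs for illustration.
        return [-0.1 * t for t in tokens]

    def logprobs(self, tokens: List[int]) -> List[float]:
        # The sync variant just runs the async one to completion.
        return asyncio.run(self.logprobs_async(tokens))


wrapper = ToyWrapper()
print(wrapper.logprobs([1, 2, 3]))
```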
- property api_key: str#
Get the API key.
- property model_version: int#
Get the version of the model.
- property model_version_async: int#
Get the version of the model asynchronously.
- property model_path: str#
Returns the path to the model files based on the current engine type.
For 'vllm' engine: returns the model path from the configuration (config.model_path)
For 'tinker' engine: returns the path to the most recent sampler weights
- property model_path_async: str#
Returns the path to the model files based on the current engine type.
For 'vllm' engine: returns the model path from the configuration (config.model_path)
For 'tinker' engine: returns the path to the most recent sampler weights
- property model_name: str | None#
Get the name of the model.
- property model_config: InferenceModelConfig#
Get the model config.
- property generate_kwargs: Dict[str, Any]#
Get the generation kwargs for the OpenAI client.
- get_openai_client() → openai.OpenAI[source]#
Get the OpenAI client.
- Returns:
The OpenAI client, with a model_path attribute added that refers to the model path.
- Return type:
openai.OpenAI
- get_openai_async_client() → openai.AsyncOpenAI[source]#
Get the async OpenAI client.
- Returns:
The async OpenAI client, with a model_path attribute added that refers to the model path.
- Return type:
openai.AsyncOpenAI
- extract_experience_from_history(clear_history: bool = True) → List[Experience][source]#
Extract experiences from the history.
- clone_with_isolated_history() → ModelWrapper[source]#
Clone the current ModelWrapper with an isolated history.
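When enable_history is set, calls are recorded and can be drained later in a batch. A toy sketch of that record-then-extract pattern, with plain dicts standing in for Experience objects; `ToyHistoryWrapper` and its `generate` method are hypothetical stand-ins, not the real ModelWrapper API surface.

```python
from typing import Any, Dict, List


class ToyHistoryWrapper:
    """Hypothetical illustration of history recording and extraction."""

    def __init__(self) -> None:
        self.history: List[Dict[str, Any]] = []

    def generate(self, prompt: str) -> Dict[str, Any]:
        # Record every produced experience into the history buffer.
        exp = {"prompt": prompt, "response": f"echo: {prompt}"}
        self.history.append(exp)
        return exp

    def extract_experience_from_history(
        self, clear_history: bool = True
    ) -> List[Dict[str, Any]]:
        # Hand back everything recorded so far, optionally resetting the buffer.
        experiences = list(self.history)
        if clear_history:
            self.history.clear()
        return experiences


w = ToyHistoryWrapper()
w.generate("a")
w.generate("b")
batch = w.extract_experience_from_history()
print([e["prompt"] for e in batch])  # the buffer is now empty
```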
- trinity.common.models.model.convert_api_output_to_experience(output) → List[Experience][source]#
Convert non-streaming/streaming API outputs to a list of experiences.
- class trinity.common.models.model.HistoryRecordingStream(stream, history: List[Experience], is_async: bool = False)[source]#
Bases: object
- __init__(stream, history: List[Experience], is_async: bool = False) → None[source]#