# Source code for trinity.buffer.operators.experience_operator
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Dict, List, Tuple
from trinity.common.config import OperatorConfig
from trinity.common.experience import Experience
if TYPE_CHECKING:
from openai import AsyncOpenAI
[文档]
class ExperienceOperator(ABC):
    """Legacy synchronous interface for experience operators.

    An operator takes a batch of experiences and returns a transformed batch
    together with a dictionary of metrics.

    This interface is slated for deprecation in favor of
    ``ExperienceOperatorV1``, which supports asynchronous processing and access
    to auxiliary models. Do not build new operators on this class; use
    ``ExperienceOperatorV1`` instead.
    """

    @abstractmethod
    def process(self, exps: List[Experience]) -> Tuple[List[Experience], Dict]:
        """Transform a batch of experiences.

        Args:
            exps (List[Experience]): The experiences to process, i.e. all
                experiences generated by the Explorer in one explore step.

        Returns:
            Tuple[List[Experience], Dict]: The processed experiences and a
            dictionary of metrics.
        """

    def close(self):
        """Release any resources held by the operator (no-op by default)."""
        return None
[文档]
class ExperienceOperatorV1(ABC):
    """Asynchronous experience operator interface with auxiliary-model access.

    This is the enhanced successor of ``ExperienceOperator``: ``prepare``,
    ``process`` and ``close`` are coroutines, and auxiliary models can be
    attached through ``set_auxiliary_model``.
    """

    def set_auxiliary_model(
        self, auxiliary_models: Dict[str | int, List["AsyncOpenAI"]] | None = None
    ) -> None:
        """Attach the auxiliary models to this operator.

        Args:
            auxiliary_models: Mapping from a model identifier to a list of
                clients. A falsy value (``None`` or an empty mapping) is
                stored as a fresh empty dict.
        """
        if auxiliary_models:
            self.auxiliary_models = auxiliary_models
        else:
            self.auxiliary_models = {}

    async def prepare(self) -> None:
        """Run any asynchronous initialization (no-op by default)."""
        return None

    @abstractmethod
    async def process(self, exps: List[Experience]) -> Tuple[List[Experience], Dict]:
        """Transform a batch of experiences.

        Args:
            exps (List[Experience]): The experiences to process, i.e. all
                experiences generated by the Explorer in one explore step.

        Returns:
            Tuple[List[Experience], Dict]: The processed experiences and a
            dictionary of metrics.
        """

    async def close(self):
        """Release any resources held by the operator (no-op by default)."""
        return None
[文档]
class ExperienceOperatorV1Wrapper(ExperienceOperatorV1):
    """Expose a legacy ``ExperienceOperator`` through the async V1 interface."""

    def __init__(self, operator: ExperienceOperator):
        """Keep a reference to the legacy operator that calls are delegated to.

        Args:
            operator (ExperienceOperator): The legacy operator to adapt.
        """
        self.operator = operator

    async def process(self, exps: List[Experience]) -> Tuple[List[Experience], Dict]:
        """Call the wrapped operator's synchronous ``process`` and return its result.

        The wrapped call is made directly inside this coroutine; it is not
        awaited or offloaded.
        """
        legacy = self.operator
        return legacy.process(exps)

    async def close(self):
        """Call the wrapped operator's ``close`` and return whatever it returns."""
        legacy = self.operator
        return legacy.close()
[文档]
def ensure_v1_operator(operator: ExperienceOperator | ExperienceOperatorV1) -> ExperienceOperatorV1:
    """Return an object exposing the ``ExperienceOperatorV1`` interface.

    Args:
        operator: Either a V1 operator or a legacy ``ExperienceOperator``.

    Returns:
        ExperienceOperatorV1: ``operator`` itself when it is already a V1
        operator, otherwise an ``ExperienceOperatorV1Wrapper`` around it.

    Raises:
        TypeError: If ``operator`` is an instance of neither interface.
    """
    # V1 operators pass through untouched; this check takes precedence.
    if isinstance(operator, ExperienceOperatorV1):
        return operator
    if not isinstance(operator, ExperienceOperator):
        raise TypeError(f"Unsupported operator instance type: {type(operator)}")
    return ExperienceOperatorV1Wrapper(operator)
[文档]
def create_operators(
    operator_configs: List[OperatorConfig],
    auxiliary_models: Dict[str | int, List["AsyncOpenAI"]] | None = None,
) -> List[ExperienceOperatorV1]:
    """Instantiate the operators described by ``operator_configs``.

    Each config's ``name`` is looked up in the ``EXPERIENCE_OPERATORS``
    registry, the class is instantiated with ``config.args`` as keyword
    arguments, adapted to the V1 interface if needed, and handed the
    auxiliary models.

    Args:
        operator_configs (List[OperatorConfig]): Operator configurations, in
            the order the operators should be created.
        auxiliary_models (Dict[str | int, List["AsyncOpenAI"]], optional):
            Auxiliary models made available to every operator. Keys are model
            identifiers and values are lists of openai.AsyncOpenAI instances.
            Defaults to None.

    Returns:
        List[ExperienceOperatorV1]: The instantiated operators, in config order.

    Raises:
        ValueError: If a config names an operator that is not registered, or
            the registered class implements neither operator interface.
    """
    # Imported here rather than at module level, as in the original code.
    from trinity.buffer.operators import EXPERIENCE_OPERATORS

    supported_bases = (ExperienceOperatorV1, ExperienceOperator)
    created = []
    for cfg in operator_configs:
        operator_cls = EXPERIENCE_OPERATORS.get(cfg.name)
        if not operator_cls:
            raise ValueError(f"Unknown operator: {cfg.name}")
        if not issubclass(operator_cls, supported_bases):
            raise ValueError(f"Unknown operator type: {cfg.name}")
        instance = ensure_v1_operator(operator_cls(**cfg.args))
        instance.set_auxiliary_model(auxiliary_models)
        created.append(instance)
    return created