# Source code for trinity.common.models.external_model

import os
from typing import Dict, List, Optional, Sequence, Union

import torch

from trinity.common.config import InferenceModelConfig
from trinity.common.experience import Experience
from trinity.common.models.model import InferenceModel


class ExternalModel(InferenceModel):
    """Inference model backed by an external OpenAI-compatible API.

    Generation is delegated to a remote chat-completions endpoint instead of
    a local engine.  Token-level fields (``tokens`` / ``logprobs``) are filled
    with minimal placeholder tensors so downstream pipelines that expect them
    can still process the resulting :class:`Experience` objects.
    """

    def __init__(self, config: InferenceModelConfig) -> None:
        """Read endpoint settings from *config* and the environment.

        Args:
            config: Inference model configuration; ``external_model_config``
                supplies the env-var names for the base URL / API key and an
                optional explicit model name.

        Raises:
            ValueError: If neither ``external_model_config.model_name`` nor
                ``config.model_path`` provides a model name.
        """
        super().__init__(config)
        self.model_version = 0
        # The OpenAI client is created lazily in `prepare()` so constructing
        # this object needs neither the `openai` package nor network access.
        self.client = None
        self.request_count = 0
        # Trailing slashes are stripped so the suffix checks in
        # `get_api_server_url` behave predictably.
        self.api_base_url = os.getenv(config.external_model_config.base_url_env, "").rstrip("/")
        self.api_model_name = config.external_model_config.model_name or config.model_path
        if self.api_model_name is None:
            raise ValueError("`api_model_name` or `model_path` must be provided for openai_api.")

    async def prepare(self) -> None:
        """Lazily create the async OpenAI client (idempotent)."""
        if self.client is not None:
            return
        # Imported here so the third-party dependency is only required when
        # this backend is actually used.
        import openai

        self.client = openai.AsyncOpenAI(
            base_url=self.api_base_url,
            api_key=self.get_api_key(),
        )
        self.logger.info(
            "Initialized external model engine with base_url=%s, model_name=%s, api_key_env=%s",
            self.api_base_url,
            self.api_model_name,
            self.config.external_model_config.api_key_env,
        )

    def _build_experience(
        self,
        response_text: str,
        prompt_text: str = "",
        reward: Optional[float] = 0.0,
        metrics: Optional[dict[str, float]] = None,
        info: Optional[dict] = None,
    ) -> Experience:
        """Wrap API output in an :class:`Experience` with placeholder tensors."""
        # Keep a minimal valid token tensor so existing pipelines can process
        # single-turn data.
        tokens = torch.tensor([0, 0], dtype=torch.int32)
        logprobs = torch.tensor([0.0], dtype=torch.float32)
        return Experience(
            tokens=tokens,
            logprobs=logprobs,
            prompt_length=1,
            prompt_text=prompt_text,
            response_text=response_text,
            reward=reward,
            metrics=metrics or {},
            info=info or {},
        )

    async def _request_chat_completion(
        self, messages: List[Dict], **kwargs
    ) -> Sequence[Experience]:
        """Send one chat-completion request and convert choices to experiences.

        Args:
            messages: OpenAI-style chat messages.
            **kwargs: Optional overrides: ``temperature``, ``n``, and
                ``max_completion_tokens`` (falling back to ``max_tokens``,
                then the configured ``max_response_tokens``).

        Returns:
            One :class:`Experience` per returned choice; rewards are 0.0 and
            token usage (when reported) is recorded under ``usage/*`` metrics.
        """
        await self.prepare()
        assert self.client is not None
        self.request_count += 1
        request_id = self.request_count
        max_completion_tokens = kwargs.get("max_completion_tokens")
        if max_completion_tokens is None:
            max_completion_tokens = kwargs.get("max_tokens", self.config.max_response_tokens)
        req_kwargs = {
            "model": self.api_model_name,
            "messages": messages,
            "temperature": kwargs.get("temperature", self.config.temperature),
            "max_completion_tokens": max_completion_tokens,
            "n": kwargs.get("n", 1),
        }
        # BUG FIX: the original format string carried a fifth `top_p=%s`
        # placeholder with no matching argument, so every call produced a
        # logging formatting error; `top_p` is not part of the request, so
        # the placeholder is dropped.
        self.logger.debug(
            "[openai_api][request=%d] model=%s max_completion_tokens=%s temperature=%s",
            request_id,
            req_kwargs["model"],
            req_kwargs["max_completion_tokens"],
            req_kwargs["temperature"],
        )
        response = await self.client.chat.completions.create(**req_kwargs)
        usage_metrics = {}
        usage = getattr(response, "usage", None)
        if usage is not None:
            for usage_key in ("prompt_tokens", "completion_tokens", "total_tokens"):
                usage_val = getattr(usage, usage_key, None)
                if isinstance(usage_val, (int, float)):
                    usage_metrics[f"usage/{usage_key}"] = float(usage_val)
        exps = []
        for choice in response.choices:
            content = choice.message.content or ""
            exps.append(
                self._build_experience(
                    response_text=content,
                    reward=0.0,
                    # Copy so per-choice mutation downstream cannot leak
                    # across experiences.
                    metrics=usage_metrics.copy(),
                    info={
                        "finish_reason": choice.finish_reason,
                        "choice_index": choice.index,
                    },
                )
            )
        return exps

    async def chat(self, messages: List[Dict], **kwargs) -> Sequence[Experience]:
        """Alias of :meth:`generate` for chat-style message input."""
        return await self.generate(messages, **kwargs)

    async def chat_async(self, messages: List[Dict], **kwargs) -> Sequence[Experience]:
        """Alias of :meth:`generate` for chat-style message input."""
        return await self.generate(messages, **kwargs)

    async def generate_async(
        self, prompt: Union[str, List[Dict]], **kwargs
    ) -> Sequence[Experience]:
        """Alias of :meth:`generate`."""
        return await self.generate(prompt, **kwargs)

    async def generate(self, prompt: Union[str, List[Dict]], **kwargs) -> Sequence[Experience]:
        """Generate experiences for a raw prompt string or a message list.

        A plain string is wrapped as a single user message; a list is passed
        through unchanged.

        Raises:
            TypeError: If *prompt* is neither ``str`` nor ``list``.
        """
        if isinstance(prompt, str):
            messages = [{"role": "user", "content": prompt}]
        elif isinstance(prompt, list):
            messages = prompt
        else:
            raise TypeError(f"Unsupported prompt type: {type(prompt)}")
        return await self._request_chat_completion(messages, **kwargs)

    async def logprobs(self, token_ids: List[int], **kwargs) -> torch.Tensor:
        """Unsupported for this backend; always raises ``NotImplementedError``."""
        raise NotImplementedError("logprobs for external APIs is not implemented.")

    async def convert_messages_to_experience(
        self,
        messages: List[dict],
        tools: Optional[List[dict]] = None,
        temperature: Optional[float] = None,
    ) -> Experience:
        """Convert a chat history into a single placeholder :class:`Experience`.

        The response text is taken from the final message only when it is an
        assistant turn with string content; otherwise it stays empty.  The
        *temperature* argument is accepted for interface compatibility and
        ignored.

        Raises:
            ValueError: If *messages* is empty.
        """
        del temperature
        if not messages:
            raise ValueError("`messages` must not be empty.")
        response_text = ""
        last = messages[-1]
        if last.get("role") == "assistant":
            content = last.get("content")
            response_text = content if isinstance(content, str) else ""
        exp = self._build_experience(response_text=response_text)
        exp.messages = messages
        exp.tools = tools
        return exp

    async def sync_model(self, model_version: int) -> int:
        """Record *model_version* locally and return it.

        The remote API model itself is not updated — only the version
        bookkeeping on this wrapper changes.
        """
        self.model_version = model_version
        return self.model_version

    def get_model_version(self) -> int:
        """Return the last version recorded via :meth:`sync_model`."""
        return self.model_version

    def get_api_key(self) -> str:
        """Return the API key from the configured env var, or ``"EMPTY"``."""
        return os.getenv(self.config.external_model_config.api_key_env, "EMPTY")

    def get_api_server_url(self) -> Optional[str]:
        """Return the server root URL without a trailing ``/v1`` suffix.

        ModelWrapper appends ``/v1`` when building the openai client
        base_url, so any existing suffix is stripped here to avoid a doubled
        path.  The ``/v1/`` branch is defensive only: ``api_base_url`` is
        right-stripped of slashes in ``__init__``, so it normally cannot end
        with a slash.
        """
        if self.api_base_url.endswith("/v1"):
            return self.api_base_url[: -len("/v1")]
        elif self.api_base_url.endswith("/v1/"):
            return self.api_base_url[: -len("/v1/")]
        return self.api_base_url