trinity.explorer.workflow_runner 源代码

# -*- coding: utf-8 -*-
"""The Workflow Runner Module."""
import asyncio
import os
import time
import traceback
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

from trinity.buffer import get_buffer_reader, get_buffer_writer
from trinity.common.config import Config, StorageConfig
from trinity.common.constants import LOG_DIR_ENV_VAR, LOG_LEVEL_ENV_VAR
from trinity.common.experience import Experience
from trinity.common.models import get_debug_explorer_model
from trinity.common.models.model import InferenceModel, ModelWrapper
from trinity.common.workflows import Task, Workflow
from trinity.utils.log import get_logger


@dataclass(frozen=True)
class Status:
    """Immutable summary of how the runs of one task turned out.

    The ``ok`` property reports whether every requested run finished.
    """

    # How many runs finished successfully.
    completed_runs: int
    # How many runs were requested for the task.
    total_runs: int
    # Metric dictionaries; each one comes from a single run.
    metrics: List[Dict[str, float]]
    # Optional free-form note (e.g. an error traceback); None when there is nothing to report.
    message: Optional[str] = None

    @property
    def ok(self) -> bool:
        """Return True when the number of completed runs equals the number requested."""
        return self.total_runs == self.completed_runs
@dataclass(frozen=True)
class RunnerExecutionResult:
    """Immutable pair of run bookkeeping and the experiences a runner produced for one task."""

    # Completion counts, per-run metrics and an optional message for the task.
    status: Status
    # Experiences produced by the task's runs.
    experiences: List[Experience]
def calculate_run_level_metrics(experiences: List[Experience]) -> Dict[str, float]:
    """Average per-experience metrics into one metrics dictionary.

    Every metric name that appears in any experience's ``metrics`` mapping is
    averaged over the experiences that report it. Experiences whose ``metrics``
    is empty or ``None`` are skipped.

    For non-repeatable workflows, averaging the metrics of the experiences
    produced by one run is equivalent to computing run-level metrics. Do not use
    this function for repeatable workflows.

    Args:
        experiences: Experiences produced by a single run.

    Returns:
        Mapping from metric name to its mean value; empty when no experience
        carries metrics.
    """
    samples_by_name: Dict[str, List[float]] = defaultdict(list)
    for experience in experiences:
        if not experience.metrics:
            continue
        for metric_name, metric_value in experience.metrics.items():
            samples_by_name[metric_name].append(metric_value)
    return {
        metric_name: sum(samples) / len(samples)
        for metric_name, samples in samples_by_name.items()
    }
class WorkflowRunner:
    """A Ray remote actor to run the workflow and generate experiences."""

    def __init__(
        self,
        config: Config,
        model: InferenceModel,
        auxiliary_models: Optional[List[InferenceModel]] = None,
        runner_id: Optional[int] = None,
    ) -> None:
        """Build the runner and select the strategy used for non-repeatable workflows.

        Args:
            config: Global configuration. ``config.explorer`` supplies the runner
                name, rollout/auxiliary model settings and ``concurrent_mode``.
            model: The rollout model; it is wrapped in a ``ModelWrapper``.
            auxiliary_models: Optional extra models, each wrapped in its own
                ``ModelWrapper``.
            runner_id: Identifier appended to the runner name.
        """
        self.name = f"{config.explorer.name}_runner_{runner_id}"
        self.logger = get_logger(self.name, in_ray_actor=True)
        self.config = config
        self.model = model
        self.model_wrapper = ModelWrapper(
            model,
            enable_lora=config.explorer.rollout_model.enable_lora,
            enable_history=config.explorer.rollout_model.enable_history,
        )
        self.auxiliary_models = auxiliary_models or []
        # zip() pairs each auxiliary model with its config entry, so only as many
        # wrappers are built as there are (model, config) pairs.
        self.auxiliary_model_wrappers = [
            ModelWrapper(aux_model)
            for aux_model, _aux_model_config in zip(
                self.auxiliary_models, config.explorer.auxiliary_models
            )
        ]
        # Cached workflow; reused by _create_workflow_instance when possible.
        self.workflow_instance: Optional[Workflow] = None
        self.runner_id = runner_id
        # Progress info reported by get_runner_state().
        self.runner_state: Dict = {
            "workflow_id": None,
            "model_version": None,
            "begin_time": 0,
            "terminate_time": 0,
        }
        self.concurrent_mode = config.explorer.concurrent_mode
        if self.concurrent_mode == "sequential":
            self.concurrent_run_fn = self._sequential_run
        elif self.concurrent_mode == "asynchronous":
            self.concurrent_run_fn = self._asynchronous_run
        elif self.concurrent_mode == "multi-threading":
            self.concurrent_run_fn = self._multi_threading_run
        else:
            self.logger.warning(
                f"Unknown concurrent_mode {self.concurrent_mode}, defaulting to sequential."
            )
            self.concurrent_run_fn = self._sequential_run
        self.logger.info(
            f"WorkflowRunner [{self.name}]({self.concurrent_mode}) initialized:\n"
            f" > rollout model: {self.config.explorer.rollout_model.model_path}\n"
            f" > auxiliary models: {[aux_model_config.model_path for aux_model_config in self.config.explorer.auxiliary_models]}"
        )

    async def prepare(self) -> None:
        """Prepare the runner by preparing the rollout and auxiliary model wrappers concurrently."""
        await asyncio.gather(
            self.model_wrapper.prepare(),
            *(aux_model.prepare() for aux_model in self.auxiliary_model_wrappers),
        )
        self.logger.info(f"WorkflowRunner [{self.name}] is prepared and ready to run tasks.")

    def is_alive(self):
        """Liveness probe; always returns True."""
        return True

    def _create_workflow_instance(self, task: Task) -> Workflow:
        """Return a workflow for ``task``, reusing the cached instance when possible.

        The cached instance is reset and reused only if it has exactly the task's
        workflow class and is resettable; otherwise a new one is built.

        Raises:
            ValueError: If ``task.workflow`` is not set.
        """
        if task.workflow is None:
            raise ValueError("Workflow is not set in the task.")
        if (
            self.workflow_instance is None
            or self.workflow_instance.__class__ is not task.workflow
            or not self.workflow_instance.resettable
        ):
            # Pass ModelWrapper directly; Workflow.__init__ will get OpenAI clients automatically
            self.workflow_instance = task.to_workflow(
                self.model_wrapper,
                self.auxiliary_model_wrappers,
            )
        else:
            self.workflow_instance.reset(task)
        return self.workflow_instance

    async def _run_workflow(self, workflow_instance: Workflow) -> List[Experience]:
        """Run the workflow, awaiting ``run_async`` for asynchronous workflows and calling ``run`` otherwise."""
        if workflow_instance.asynchronous:
            return await workflow_instance.run_async()
        return workflow_instance.run()

    def _create_isolated_workflow_instance(self, task: Task) -> Workflow:
        """Build a fresh workflow for ``task`` for use by one concurrent run.

        When history is enabled for the rollout model, the workflow gets a clone of
        the model wrapper with isolated history, so concurrent runs do not share it.
        """
        return task.to_workflow(
            self.model_wrapper.clone_with_isolated_history()
            if self.config.explorer.rollout_model.enable_history
            else self.model_wrapper,
            self.auxiliary_model_wrappers,
        )

    def _build_execution_result(
        self,
        total_runs: int,
        completed_runs: int,
        metrics: List[Dict[str, float]],
        experiences: List[Experience],
        first_error: Optional[str] = None,
    ) -> RunnerExecutionResult:
        """Assemble a RunnerExecutionResult.

        If any run failed (``first_error`` set), the status message carries the
        first error; when some runs still succeeded, it is prefixed with the
        success count.
        """
        if first_error is None:
            message = None
        elif completed_runs > 0:
            message = (
                f"{completed_runs}/{total_runs} runs completed successfully. "
                f"First error: {first_error}"
            )
        else:
            message = first_error
        return RunnerExecutionResult(
            status=Status(
                completed_runs=completed_runs,
                total_runs=total_runs,
                metrics=list(metrics),
                message=message,
            ),
            experiences=experiences,
        )

    def _aggregate_run_results(
        self,
        total_runs: int,
        results: List[Tuple[bool, List[Experience], Optional[Dict[str, float]], Optional[str]]],
    ) -> RunnerExecutionResult:
        """Merge per-run ``(ok, experiences, run_metric, error)`` tuples.

        Experiences and metrics are kept only from successful runs, and the
        completed-run count is the number of collected run metrics. Only the
        first error is retained.
        """
        exps = []
        run_metrics = []
        first_error = None
        for ok, new_exps, run_metric, error in results:
            if ok:
                exps.extend(new_exps)
                if run_metric is not None:
                    run_metrics.append(run_metric)
                continue
            if first_error is None:
                first_error = error
        return self._build_execution_result(
            total_runs=total_runs,
            completed_runs=len(run_metrics),
            metrics=run_metrics,
            experiences=exps,
            first_error=first_error,
        )

    @staticmethod
    async def _cancel_and_drain(futures: List[asyncio.Task]) -> None:
        """Cancel every still-running future in ``futures`` and wait until all have finished.

        ``return_exceptions=True`` absorbs cancellations and errors, so no task is
        left with an unretrieved exception.
        """
        for future in futures:
            future.cancel()
        if futures:
            await asyncio.gather(*futures, return_exceptions=True)

    async def _run_parallel_runs(
        self,
        task: Task,
        repeat_times: int,
        run_id_base: int,
        collect_partial_runs: bool = True,
        use_threads: bool = False,
    ) -> RunnerExecutionResult:
        """Run ``repeat_times`` independent runs of ``task`` concurrently.

        Each run gets its own workflow from ``_create_isolated_workflow_instance``.
        With ``use_threads`` each run is executed by ``asyncio.run`` in a worker
        thread (via ``asyncio.to_thread``); otherwise all runs share this event loop.

        Args:
            task: The task to run.
            repeat_times: Number of runs to launch.
            run_id_base: Offset added to each run index when numbering experiences.
            collect_partial_runs: If True, wait for every run and keep the successful
                ones. If False, stop at the first failed run and cancel the rest.
            use_threads: Run each repetition in its own thread.

        Raises:
            Exception: Any error raised by a run outside its own error handling
                (e.g. while creating its workflow). In-flight sibling runs are
                cancelled before the error propagates.
        """

        async def run_single(
            i: int,
        ) -> Tuple[bool, List[Experience], Optional[Dict[str, float]], Optional[str]]:
            workflow = self._create_isolated_workflow_instance(task)
            return await self._execute_single_run(workflow, task, i, run_id_base)

        def spawn(i: int) -> asyncio.Task:
            if use_threads:

                def run_in_thread() -> Tuple[
                    bool, List[Experience], Optional[Dict[str, float]], Optional[str]
                ]:
                    # The worker thread drives its own event loop for this single run.
                    return asyncio.run(run_single(i))

                return asyncio.create_task(asyncio.to_thread(run_in_thread))
            return asyncio.create_task(run_single(i))

        futures = [spawn(i) for i in range(repeat_times)]

        if collect_partial_runs:
            try:
                gathered = await asyncio.gather(*futures)
            except BaseException:
                # gather() does not cancel the other runs when one raises; stop them
                # here so they do not keep running after this task is reported.
                await self._cancel_and_drain(futures)
                raise
            return self._aggregate_run_results(repeat_times, gathered)

        results = []
        pending = set(futures)
        try:
            while pending:
                done, pending = await asyncio.wait(
                    pending,
                    return_when=asyncio.FIRST_COMPLETED,
                )
                should_stop = False
                for future in done:
                    result = future.result()
                    results.append(result)
                    if not result[0]:
                        should_stop = True
                if should_stop:
                    break
        finally:
            # Cancels runs still in flight after an early stop, after an unexpected
            # exception from a finished run, or when this coroutine is cancelled.
            await self._cancel_and_drain(futures)
        return self._aggregate_run_results(repeat_times, results)

    async def _execute_single_run(
        self,
        workflow: Workflow,
        task: Task,
        run_index: int,
        run_id_base: int,
    ) -> Tuple[bool, List[Experience], Optional[Dict[str, float]], Optional[str]]:
        """Execute one run of ``workflow`` and report ``(ok, experiences, run_metric, error)``.

        On success, each experience's ``eid.run`` is set to
        ``run_id_base + run_index``, and the run metrics include
        ``time/run_execution``. Exceptions raised by the workflow are logged and
        returned as ``(False, [], None, traceback)``. Cleaning the workflow state
        happens before the guarded section, so errors there propagate.
        """
        st = time.time()
        await self.model_wrapper.clean_workflow_state()
        # NOTE: runner_state is shared; concurrent runs overwrite each other's entries.
        self.runner_state["workflow_id"] = f"{task.batch_id}/{task.task_id}/{run_index}"
        self.runner_state["terminate_time"] = None
        self.runner_state["begin_time"] = st
        try:
            new_exps = await self._run_workflow(workflow)
            et = time.time()
            self.runner_state["terminate_time"] = et
            run_metric = calculate_run_level_metrics(new_exps)
            run_metric["time/run_execution"] = et - st
            for exp in new_exps:
                exp.eid.run = run_id_base + run_index
            return True, new_exps, run_metric, None
        except Exception as exc:
            self.runner_state["terminate_time"] = time.time()
            error_trace_back = traceback.format_exc()
            self.logger.error(
                "WorkflowRunner single run error: "
                f"{exc}\nTraceback:\n{error_trace_back}"
            )
            return False, [], None, error_trace_back.rstrip()

    async def _run_task(
        self,
        task: Task,
        repeat_times: int,
        run_id_base: int,
        collect_partial_runs: bool = True,
    ) -> RunnerExecutionResult:
        """Init workflow from the task and run it.

        Repeatable workflows perform all ``repeat_times`` repetitions in a single
        call. Other workflows are dispatched to the strategy selected from
        ``concurrent_mode``.

        Raises:
            ValueError: If ``task.workflow`` is not set.
        """
        # Checked up front: the ``can_repeat`` lookup below would otherwise fail with
        # an AttributeError before _create_workflow_instance could report this.
        if task.workflow is None:
            raise ValueError("Workflow is not set in the task.")
        if not task.workflow.can_repeat:
            return await self.concurrent_run_fn(
                task,
                repeat_times,
                run_id_base,
                collect_partial_runs=collect_partial_runs,
            )

        workflow_instance = self._create_workflow_instance(task)
        workflow_instance.set_repeat_times(repeat_times, run_id_base)
        st = time.time()
        await self.model_wrapper.clean_workflow_state()
        self.runner_state["workflow_id"] = f"{task.batch_id}/{task.task_id}/{run_id_base}"
        self.runner_state["terminate_time"] = None
        self.runner_state["begin_time"] = st
        exps = await self._run_workflow(workflow_instance)
        et = time.time()
        self.runner_state["terminate_time"] = et
        # repeatable workflow cannot calculate run level metrics, we use experience level metrics directly
        # NOTE: the elapsed time is written into each experience's own metrics dict.
        run_metrics = [exp.metrics for exp in exps if exp.metrics]
        for metric in run_metrics:
            metric["time/run_execution"] = et - st
        return self._build_execution_result(
            total_runs=repeat_times,
            completed_runs=repeat_times,
            metrics=run_metrics,
            experiences=exps,
        )

    async def _sequential_run(
        self,
        task: Task,
        repeat_times: int,
        run_id_base: int,
        collect_partial_runs: bool = True,
    ) -> RunnerExecutionResult:
        """Run the repetitions one after another, reusing the cached workflow when possible.

        When ``collect_partial_runs`` is False, the first failed run ends the loop.
        """
        results = []
        for i in range(repeat_times):
            workflow = self._create_workflow_instance(task)
            result = await self._execute_single_run(workflow, task, i, run_id_base)
            results.append(result)
            if not collect_partial_runs and not result[0]:
                break
        return self._aggregate_run_results(repeat_times, results)

    async def _asynchronous_run(
        self,
        task: Task,
        repeat_times: int,
        run_id_base: int,
        collect_partial_runs: bool = True,
    ) -> RunnerExecutionResult:
        """Run the repetitions concurrently on the current event loop."""
        return await self._run_parallel_runs(
            task,
            repeat_times,
            run_id_base,
            collect_partial_runs=collect_partial_runs,
        )

    async def _multi_threading_run(
        self,
        task: Task,
        repeat_times: int,
        run_id_base: int,
        collect_partial_runs: bool = True,
    ) -> RunnerExecutionResult:
        """Run the repetitions concurrently, each in its own worker thread."""
        return await self._run_parallel_runs(
            task,
            repeat_times,
            run_id_base,
            collect_partial_runs=collect_partial_runs,
            use_threads=True,
        )

    async def get_runner_state(self) -> Dict:
        """Get the runner state: a copy of ``runner_state`` merged with the model wrapper's workflow state."""
        runner_state = self.runner_state.copy()
        runner_state.update(await self.model_wrapper.get_workflow_state())
        return runner_state

    async def run_task(
        self,
        task: Task,
        batch_id: str,
        repeat_times: int = 1,
        run_id_base: int = 0,
        collect_partial_runs: bool = True,
    ) -> Tuple[Status, bytes]:
        """Run the task and return the states.

        Experiences are stamped with the batch/task ids, the model version, a zero
        ``use_count`` and the task index, and are then serialized. Evaluation tasks
        return an empty payload. Any exception is logged and reported as a Status
        with ``completed_runs=0`` whose message holds the traceback.

        Args:
            task: The task to run.
            batch_id: Step identifier used in the start-of-task log line.
            repeat_times: Number of runs requested.
            run_id_base: Offset for run numbering.
            collect_partial_runs: Keep successful runs when some fail (see
                ``_run_parallel_runs``).

        Returns:
            A ``(status, serialized_experiences)`` pair.
        """
        # TODO: avoid sending the experiences back to the scheduler to reduce the communication overhead
        st = time.time()
        try:
            model_version = await self.model_wrapper.model_version_async
            self.runner_state["model_version"] = model_version
            self.logger.info(
                f"Starting task: step={batch_id}, model_version={model_version}, repeat_times={repeat_times}, run_id_base={run_id_base}"
            )
            execution_result = await self._run_task(
                task,
                repeat_times,
                run_id_base,
                collect_partial_runs=collect_partial_runs,
            )
            exps = execution_result.experiences
            # Explicit check rather than ``assert`` so it still runs under ``python -O``.
            if execution_result.status.completed_runs > 0 and not exps:
                raise RuntimeError("An empty experience is generated")
            # set eid for each experience
            for exp in exps:
                exp.eid.batch = task.batch_id
                # keep exp.eid.task if it has been set before (e.g., in workflow)
                if exp.eid.task == "":  # "" is the default value
                    exp.eid.task = task.task_id
                if not hasattr(exp, "info") or exp.info is None:
                    exp.info = {}
                exp.info["model_version"] = model_version
                exp.info["use_count"] = 0
                exp.info["task_index"] = task.index
                if not hasattr(exp, "metrics") or exp.metrics is None:
                    exp.metrics = {}
            status = execution_result.status
            if task.is_eval:
                # If the task is an evaluation task, we do not record the experiences to the buffer
                return status, b""
            return status, Experience.serialize_many(exps)
        except Exception as e:
            error_trace_back = traceback.format_exc()
            self.logger.error(f"WorkflowRunner run task error: {e}\nTraceback:\n{error_trace_back}")
            return (
                Status(
                    completed_runs=0,
                    total_runs=repeat_times,
                    metrics=[{"time/run_execution": time.time() - st}],
                    message=error_trace_back.rstrip(),
                ),
                b"",
            )
class DebugWorkflowRunner(WorkflowRunner):
    """A WorkflowRunner for debugging."""

    def __init__(
        self,
        config: Config,
        output_dir: str = "debug_output",
        enable_profiling: bool = False,
        disable_overwrite: bool = False,
    ) -> None:
        """Create a debug runner that writes logs, profiling output and experiences under ``output_dir``.

        Side effects: sets the log-directory and log-level environment variables
        (level ``DEBUG``) for the whole process, and creates ``output_dir``.

        Args:
            config: Global configuration; the first explorer taskset is read.
            output_dir: Directory for ``log/``, ``profiling.html`` and ``experiences.db``.
            enable_profiling: Wrap the debug run in a VizTracer profiler.
            disable_overwrite: If ``output_dir`` exists and is non-empty, use
                ``<output_dir>_<YYYYmmddHHMMSS>`` instead.
        """
        model, auxiliary_models = get_debug_explorer_model(config)
        if disable_overwrite:
            # if output dir is not empty, change to a new dir with datetime suffix
            if os.path.isdir(output_dir) and os.listdir(output_dir):
                suffix = time.strftime("%Y%m%d%H%M%S", time.localtime())
                output_dir = f"{output_dir}_{suffix}"
        # Set before super().__init__ so the runner's logger picks them up.
        os.environ[LOG_DIR_ENV_VAR] = os.path.join(output_dir, "log")
        os.environ[LOG_LEVEL_ENV_VAR] = "DEBUG"
        super().__init__(config, model, auxiliary_models, 0)
        self.taskset = get_buffer_reader(config.buffer.explorer_input.tasksets[0])
        self.output_dir = output_dir
        self.enable_profiling = enable_profiling
        self.logger.info(f"Debug output directory: {self.output_dir}")
        os.makedirs(self.output_dir, exist_ok=True)
        self.output_profiling_file = os.path.join(
            self.output_dir,
            "profiling.html",
        )
        self.output_sqlite_file = "sqlite:///" + os.path.join(
            self.output_dir,
            "experiences.db",
        )
        self.sqlite_writer = get_buffer_writer(
            StorageConfig(
                name="debug_buffer",
                schema_type="experience",
                path=self.output_sqlite_file,
                storage_type="sql",
                batch_size=1,
                wrap_in_ray=False,
            )
        )

    async def prepare(self) -> None:
        """Start the underlying models (via their ``prepare.remote()`` calls), then prepare the wrappers."""
        # make sure models are started
        prepare_refs = [self.model.prepare.remote()]
        prepare_refs.extend(model.prepare.remote() for model in self.auxiliary_models)
        await asyncio.gather(*prepare_refs)
        await super().prepare()

    async def debug(self) -> None:
        """Run the debug workflow on one task from the taskset and store the experiences.

        If the run fails and produced no experiences, experiences are extracted from
        the model wrapper's history instead. Everything collected is written to the
        SQLite buffer under ``output_dir``.

        Raises:
            ValueError: If the taskset yields no task.
        """
        from contextlib import nullcontext

        await self.prepare()
        tasks = await self.taskset.read_async(batch_size=1)
        if not tasks:
            raise ValueError("No task available in the taskset for debugging.")
        task = tasks[0]
        self.logger.info(f"Start debugging task:\n{task.raw_task}")
        if self.enable_profiling:
            # Imported lazily so viztracer is only required when profiling is requested.
            from viztracer import VizTracer

            profiler = VizTracer(output_file=self.output_profiling_file)
        else:
            profiler = nullcontext()
        with profiler:
            status, exp_payload = await self.run_task(
                task=task, batch_id="debug", repeat_times=1, run_id_base=0
            )
        experiences = Experience.deserialize_many(exp_payload) if exp_payload else []
        if not status.ok and not experiences:
            experiences = self.model_wrapper.extract_experience_from_history()
            self.logger.info(
                f"Debugging failed, extracting {len(experiences)} experiences from history."
            )
        await self.sqlite_writer.write_async(experiences)
        if status.ok:
            print(f"Task {task.task_id} completed successfully with metrics:\n{status.metrics}")
        else:
            self.logger.error(f"Task {task.task_id} failed with message: {status.message}")
        self.logger.info("Debugging completed.")