# -*- coding: utf-8 -*-
"""The Workflow Runner Module."""
import asyncio
import os
import time
import traceback
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
from trinity.buffer import get_buffer_reader, get_buffer_writer
from trinity.common.config import Config, StorageConfig
from trinity.common.constants import LOG_DIR_ENV_VAR, LOG_LEVEL_ENV_VAR
from trinity.common.experience import Experience
from trinity.common.models import get_debug_explorer_model
from trinity.common.models.model import InferenceModel, ModelWrapper
from trinity.common.workflows import Task, Workflow
from trinity.utils.log import get_logger
@dataclass(frozen=True)
class Status:
"""Status of the task running result."""
completed_runs: int
total_runs: int
metrics: List[Dict[str, float]]
# A list of metric dictionaries, where each dictionary is from a single run.
message: Optional[str] = None
@property
def ok(self) -> bool:
return self.completed_runs == self.total_runs
@dataclass(frozen=True)
class RunnerExecutionResult:
"""Execution result for one runner task."""
status: Status
experiences: List[Experience]
def calculate_run_level_metrics(experiences: List[Experience]) -> Dict[str, float]:
    """Calculate run-level metrics from experiences.

    For non-repeatable workflows, this function averages the metrics across the
    experiences generated by each run, which is equivalent to computing
    run-level metrics. Do not use this function for repeatable workflows.
    """
run_level_metrics: Dict[str, List[float]] = defaultdict(list)
for exp in experiences:
if exp.metrics:
for k, v in exp.metrics.items():
run_level_metrics[k].append(v)
averaged_metrics: Dict[str, float] = {}
for key, values in run_level_metrics.items():
averaged_metrics[key] = sum(values) / len(values)
return averaged_metrics
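# A minimal usage sketch for calculate_run_level_metrics; the metric names and
# the experience-construction helper below are hypothetical, only the averaging
# behavior is real:
#
#     exps = [make_experience(metrics={"reward": 1.0, "steps": 4.0}),
#             make_experience(metrics={"reward": 0.0, "steps": 6.0})]
#     calculate_run_level_metrics(exps)  # -> {"reward": 0.5, "steps": 5.0}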
class WorkflowRunner:
"""A Ray remote actor to run the workflow and generate experiences."""
def __init__(
self,
config: Config,
model: InferenceModel,
auxiliary_models: Optional[List[InferenceModel]] = None,
runner_id: Optional[int] = None,
) -> None:
self.name = f"{config.explorer.name}_runner_{runner_id}"
self.logger = get_logger(self.name, in_ray_actor=True)
self.config = config
self.model = model
self.model_wrapper = ModelWrapper(
model,
enable_lora=config.explorer.rollout_model.enable_lora,
enable_history=config.explorer.rollout_model.enable_history,
)
self.auxiliary_models = auxiliary_models or []
        self.auxiliary_model_wrappers = [
            ModelWrapper(model) for model in self.auxiliary_models
        ]
        self.workflow_instance: Optional[Workflow] = None
self.runner_id = runner_id
self.runner_state = {
"workflow_id": None,
"model_version": None,
"begin_time": 0,
"terminate_time": 0,
}
self.concurrent_mode = config.explorer.concurrent_mode
if self.concurrent_mode == "sequential":
self.concurrent_run_fn = self._sequential_run
elif self.concurrent_mode == "asynchronous":
self.concurrent_run_fn = self._asynchronous_run
elif self.concurrent_mode == "multi-threading":
self.concurrent_run_fn = self._multi_threading_run
else:
self.logger.warning(
f"Unknown concurrent_mode {self.concurrent_mode}, defaulting to sequential."
)
self.concurrent_run_fn = self._sequential_run
self.logger.info(
f"WorkflowRunner [{self.name}]({self.concurrent_mode}) initialized:\n"
f" > rollout model: {self.config.explorer.rollout_model.model_path}\n"
f" > auxiliary models: {[aux_model_config.model_path for aux_model_config in self.config.explorer.auxiliary_models]}"
)
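    # A sketch of the explorer config fields this runner reads (YAML form; the
    # field names mirror the attribute accesses above, the values are
    # hypothetical):
    #
    #     explorer:
    #       name: explorer
    #       concurrent_mode: asynchronous  # or "sequential" / "multi-threading"
    #       rollout_model:
    #         model_path: /path/to/model
    #         enable_lora: false
    #         enable_history: true
    #       auxiliary_models: []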
async def prepare(self) -> None:
"""Prepare the runner."""
await asyncio.gather(
self.model_wrapper.prepare(),
*(aux_model.prepare() for aux_model in self.auxiliary_model_wrappers),
)
self.logger.info(f"WorkflowRunner [{self.name}] is prepared and ready to run tasks.")
def is_alive(self):
return True
def _create_workflow_instance(self, task: Task) -> Workflow:
if task.workflow is None:
raise ValueError("Workflow is not set in the task.")
        if (
            self.workflow_instance is None
            or self.workflow_instance.__class__ is not task.workflow
            or not self.workflow_instance.resettable
        ):
# Pass ModelWrapper directly; Workflow.__init__ will get OpenAI clients automatically
self.workflow_instance = task.to_workflow(
self.model_wrapper,
self.auxiliary_model_wrappers,
)
else:
self.workflow_instance.reset(task)
return self.workflow_instance
async def _run_workflow(self, workflow_instance: Workflow) -> List[Experience]:
if workflow_instance.asynchronous:
exps = await workflow_instance.run_async()
else:
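            # A synchronous workflow blocks the event loop for the duration of
            # run(); the multi-threading mode exists to offload such workflows
            # to worker threads.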
exps = workflow_instance.run()
return exps
def _create_isolated_workflow_instance(self, task: Task) -> Workflow:
return task.to_workflow(
self.model_wrapper.clone_with_isolated_history()
if self.config.explorer.rollout_model.enable_history
else self.model_wrapper,
self.auxiliary_model_wrappers,
)
def _build_execution_result(
self,
total_runs: int,
completed_runs: int,
metrics: List[Dict[str, float]],
experiences: List[Experience],
first_error: Optional[str] = None,
) -> RunnerExecutionResult:
if first_error is None:
message = None
elif completed_runs > 0:
message = (
f"{completed_runs}/{total_runs} runs completed successfully. "
f"First error: {first_error}"
)
else:
message = first_error
return RunnerExecutionResult(
status=Status(
completed_runs=completed_runs,
total_runs=total_runs,
metrics=list(metrics),
message=message,
),
experiences=experiences,
)
def _aggregate_run_results(
self,
total_runs: int,
results: List[Tuple[bool, List[Experience], Optional[Dict[str, float]], Optional[str]]],
) -> RunnerExecutionResult:
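        # Each result tuple has the shape (ok, experiences, run_metric, error),
        # e.g. (True, [exp, ...], {"reward": 1.0}, None) for a successful run
        # or (False, [], None, "Traceback ...") for a failed one (metric names
        # here are hypothetical).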
exps = []
run_metrics = []
first_error = None
for ok, new_exps, run_metric, error in results:
if ok:
exps.extend(new_exps)
if run_metric is not None:
run_metrics.append(run_metric)
continue
if first_error is None:
first_error = error
return self._build_execution_result(
total_runs=total_runs,
completed_runs=len(run_metrics),
metrics=run_metrics,
experiences=exps,
first_error=first_error,
)
async def _run_parallel_runs(
self,
task: Task,
repeat_times: int,
run_id_base: int,
collect_partial_runs: bool = True,
use_threads: bool = False,
) -> RunnerExecutionResult:
async def run_single(
i: int,
) -> Tuple[bool, List[Experience], Optional[Dict[str, float]], Optional[str]]:
workflow = self._create_isolated_workflow_instance(task)
return await self._execute_single_run(workflow, task, i, run_id_base)
if collect_partial_runs:
if use_threads:
results = await asyncio.gather(
*(
asyncio.to_thread(lambda idx=i: asyncio.run(run_single(idx))) # type: ignore[misc]
for i in range(repeat_times)
)
)
else:
results = await asyncio.gather(*(run_single(i) for i in range(repeat_times)))
return self._aggregate_run_results(repeat_times, results)
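        # Fail-fast path (collect_partial_runs=False): track each run's task so
        # the remaining runs can be cancelled as soon as one fails.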
future_to_run_index = {}
for i in range(repeat_times):
if use_threads:
future = asyncio.create_task(
asyncio.to_thread(lambda idx=i: asyncio.run(run_single(idx))) # type: ignore[misc]
)
else:
future = asyncio.create_task(run_single(i))
future_to_run_index[future] = i
results = []
while future_to_run_index:
done, pending = await asyncio.wait(
future_to_run_index.keys(),
return_when=asyncio.FIRST_COMPLETED,
)
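            # Cancel everything still pending once any completed run has failed.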
should_stop = False
for future in done:
future_to_run_index.pop(future)
result = future.result()
results.append(result)
ok, _, _, _ = result
if not ok:
should_stop = True
if should_stop:
for future in pending:
future.cancel()
if pending:
await asyncio.gather(*pending, return_exceptions=True)
break
return self._aggregate_run_results(repeat_times, results)
async def _execute_single_run(
self,
workflow: Workflow,
task: Task,
run_index: int,
run_id_base: int,
) -> Tuple[bool, List[Experience], Optional[Dict[str, float]], Optional[str]]:
st = time.time()
await self.model_wrapper.clean_workflow_state()
self.runner_state["workflow_id"] = f"{task.batch_id}/{task.task_id}/{run_index}"
self.runner_state["terminate_time"] = None
self.runner_state["begin_time"] = st
try:
new_exps = await self._run_workflow(workflow)
et = time.time()
self.runner_state["terminate_time"] = et
run_metric = calculate_run_level_metrics(new_exps)
run_metric["time/run_execution"] = et - st
for exp in new_exps:
exp.eid.run = run_id_base + run_index
return True, new_exps, run_metric, None
except Exception as exc:
self.runner_state["terminate_time"] = time.time()
error_trace_back = traceback.format_exc()
            self.logger.error(
                f"WorkflowRunner single run error: {exc}\nTraceback:\n{error_trace_back}"
            )
return False, [], None, error_trace_back.rstrip()
async def _run_task(
self,
task: Task,
repeat_times: int,
run_id_base: int,
collect_partial_runs: bool = True,
) -> RunnerExecutionResult:
"""Init workflow from the task and run it."""
if task.workflow.can_repeat:
workflow_instance = self._create_workflow_instance(task)
workflow_instance.set_repeat_times(repeat_times, run_id_base)
st = time.time()
await self.model_wrapper.clean_workflow_state()
self.runner_state["workflow_id"] = f"{task.batch_id}/{task.task_id}/{run_id_base}"
self.runner_state["terminate_time"] = None
self.runner_state["begin_time"] = st
exps = await self._run_workflow(workflow_instance)
et = time.time()
self.runner_state["terminate_time"] = et
            # Repeatable workflows cannot produce run-level metrics, so we use
            # experience-level metrics directly.
run_metrics = [exp.metrics for exp in exps if exp.metrics]
for metric in run_metrics:
metric["time/run_execution"] = et - st
return self._build_execution_result(
total_runs=repeat_times,
completed_runs=repeat_times,
metrics=run_metrics,
experiences=exps,
)
else:
return await self.concurrent_run_fn(
task,
repeat_times,
run_id_base,
collect_partial_runs=collect_partial_runs,
)
async def _sequential_run(
self,
task: Task,
repeat_times: int,
run_id_base: int,
collect_partial_runs: bool = True,
) -> RunnerExecutionResult:
results = []
for i in range(repeat_times):
workflow = self._create_workflow_instance(task)
result = await self._execute_single_run(workflow, task, i, run_id_base)
results.append(result)
if collect_partial_runs:
continue
ok, _, _, _ = result
if ok:
continue
break
return self._aggregate_run_results(repeat_times, results)
async def _asynchronous_run(
self,
task: Task,
repeat_times: int,
run_id_base: int,
collect_partial_runs: bool = True,
) -> RunnerExecutionResult:
return await self._run_parallel_runs(
task,
repeat_times,
run_id_base,
collect_partial_runs=collect_partial_runs,
)
async def _multi_threading_run(
self,
task: Task,
repeat_times: int,
run_id_base: int,
collect_partial_runs: bool = True,
) -> RunnerExecutionResult:
return await self._run_parallel_runs(
task,
repeat_times,
run_id_base,
collect_partial_runs=collect_partial_runs,
use_threads=True,
)
async def get_runner_state(self) -> Dict:
"""Get the runner state."""
runner_state = self.runner_state.copy()
runner_state.update(await self.model_wrapper.get_workflow_state())
return runner_state
async def run_task(
self,
task: Task,
batch_id: str,
repeat_times: int = 1,
run_id_base: int = 0,
collect_partial_runs: bool = True,
) -> Tuple[Status, bytes]:
"""Run the task and return the states."""
# TODO: avoid sending the experiences back to the scheduler to reduce the communication overhead
st = time.time()
try:
model_version = await self.model_wrapper.model_version_async
self.runner_state["model_version"] = model_version
self.logger.info(
f"Starting task: step={batch_id}, model_version={model_version}, repeat_times={repeat_times}, run_id_base={run_id_base}"
)
execution_result = await self._run_task(
task,
repeat_times,
run_id_base,
collect_partial_runs=collect_partial_runs,
)
exps = execution_result.experiences
if execution_result.status.completed_runs > 0:
                assert exps is not None and len(exps) > 0, "Completed runs produced no experiences"
# set eid for each experience
for exp in exps:
exp.eid.batch = task.batch_id
# keep exp.eid.task if it has been set before (e.g., in workflow)
if exp.eid.task == "": # "" is the default value
exp.eid.task = task.task_id
if not hasattr(exp, "info") or exp.info is None:
exp.info = {}
exp.info["model_version"] = model_version
exp.info["use_count"] = 0
exp.info["task_index"] = task.index
if not hasattr(exp, "metrics") or exp.metrics is None:
exp.metrics = {}
status = execution_result.status
if task.is_eval:
# If the task is an evaluation task, we do not record the experiences to the buffer
return status, b""
else:
exp_payload = Experience.serialize_many(exps)
return status, exp_payload
except Exception as e:
error_trace_back = traceback.format_exc()
self.logger.error(f"WorkflowRunner run task error: {e}\nTraceback:\n{error_trace_back}")
return (
Status(
completed_runs=0,
total_runs=repeat_times,
metrics=[{"time/run_execution": time.time() - st}],
message=error_trace_back.rstrip(),
),
b"",
)
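# A minimal invocation sketch for WorkflowRunner as a Ray actor (assuming
# `import ray`; the deployment options and awaiting style are hypothetical):
#
#     runner = ray.remote(WorkflowRunner).remote(config, model)
#     await runner.prepare.remote()
#     status, payload = await runner.run_task.remote(task, batch_id="0", repeat_times=4)
#     exps = Experience.deserialize_many(payload) if payload else []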
class DebugWorkflowRunner(WorkflowRunner):
"""A WorkflowRunner for debugging."""
def __init__(
self,
config: Config,
output_dir: str = "debug_output",
enable_profiling: bool = False,
disable_overwrite: bool = False,
) -> None:
model, auxiliary_models = get_debug_explorer_model(config)
if disable_overwrite:
            # If the output dir is not empty, switch to a new dir with a datetime suffix.
if os.path.isdir(output_dir) and os.listdir(output_dir):
suffix = time.strftime("%Y%m%d%H%M%S", time.localtime())
output_dir = f"{output_dir}_{suffix}"
os.environ[LOG_DIR_ENV_VAR] = os.path.join(output_dir, "log")
os.environ[LOG_LEVEL_ENV_VAR] = "DEBUG"
super().__init__(config, model, auxiliary_models, 0)
self.taskset = get_buffer_reader(config.buffer.explorer_input.tasksets[0])
self.output_dir = output_dir
self.enable_profiling = enable_profiling
self.logger.info(f"Debug output directory: {self.output_dir}")
os.makedirs(self.output_dir, exist_ok=True)
self.output_profiling_file = os.path.join(
self.output_dir,
"profiling.html",
)
self.output_sqlite_file = "sqlite:///" + os.path.join(
self.output_dir,
"experiences.db",
)
self.sqlite_writer = get_buffer_writer(
StorageConfig(
name="debug_buffer",
schema_type="experience",
path=self.output_sqlite_file,
storage_type="sql",
batch_size=1,
wrap_in_ray=False,
)
)
async def prepare(self) -> None:
# make sure models are started
prepare_refs = [self.model.prepare.remote()]
prepare_refs.extend(model.prepare.remote() for model in self.auxiliary_models)
await asyncio.gather(*prepare_refs)
await super().prepare()
async def debug(self) -> None:
"""Run the debug workflow."""
await self.prepare()
tasks = await self.taskset.read_async(batch_size=1)
task = tasks[0]
self.logger.info(f"Start debugging task:\n{task.raw_task}")
if not self.enable_profiling:
status, exp_payload = await self.run_task(
task=task, batch_id="debug", repeat_times=1, run_id_base=0
)
else:
from viztracer import VizTracer
with VizTracer(output_file=self.output_profiling_file):
status, exp_payload = await self.run_task(
task=task, batch_id="debug", repeat_times=1, run_id_base=0
)
experiences = Experience.deserialize_many(exp_payload) if exp_payload else []
if not status.ok and not experiences:
experiences = self.model_wrapper.extract_experience_from_history()
self.logger.info(
f"Debugging failed, extracting {len(experiences)} experiences from history."
)
await self.sqlite_writer.write_async(experiences)
if status.ok:
print(f"Task {task.task_id} completed successfully with metrics:\n{status.metrics}")
else:
self.logger.error(f"Task {task.task_id} failed with message: {status.message}")
self.logger.info("Debugging completed.")