trinity.explorer.scheduler module

Scheduler for rollout tasks.

class trinity.explorer.scheduler.TaskWrapper(task: trinity.common.workflows.workflow.Task, batch_id: int | str, sub_task_num: int = 1, finished_sub_task_num: int = 0, completed_runs: int = 0, total_runs: int = 0, metrics: typing.List[typing.Dict[str, float]] = <factory>, experience_payloads: typing.List[bytes] = <factory>, first_error: str | None = None, emitted: bool = False)

Bases: object

A wrapper for a task. Each task can run multiple times (repeat_times) on the same or different runners.

task: Task
batch_id: int | str
sub_task_num: int = 1
finished_sub_task_num: int = 0
completed_runs: int = 0
total_runs: int = 0
metrics: List[Dict[str, float]]
experience_payloads: List[bytes]
first_error: str | None = None
emitted: bool = False
__init__(task: trinity.common.workflows.workflow.Task, batch_id: int | str, sub_task_num: int = 1, finished_sub_task_num: int = 0, completed_runs: int = 0, total_runs: int = 0, metrics: typing.List[typing.Dict[str, float]] = <factory>, experience_payloads: typing.List[bytes] = <factory>, first_error: str | None = None, emitted: bool = False) -> None
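The `<factory>` defaults in the signature above suggest a dataclass whose list fields are built with `default_factory`. As a minimal sketch under that assumption (the class name `TaskWrapperSketch` is hypothetical, `Any` stands in for the real `Task` type, and `list` is assumed to be the factory), the documented fields could be declared like this:

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union


@dataclass
class TaskWrapperSketch:
    """Illustrative stand-in that mirrors the documented TaskWrapper fields."""

    task: Any  # stand-in for trinity.common.workflows.workflow.Task
    batch_id: Union[int, str]
    sub_task_num: int = 1
    finished_sub_task_num: int = 0
    completed_runs: int = 0
    total_runs: int = 0
    # Assumption: <factory> means each instance gets its own fresh list.
    metrics: List[Dict[str, float]] = field(default_factory=list)
    experience_payloads: List[bytes] = field(default_factory=list)
    first_error: Optional[str] = None
    emitted: bool = False


a = TaskWrapperSketch(task="task-a", batch_id=1)
b = TaskWrapperSketch(task="task-b", batch_id="1/my_eval_set")
a.metrics.append({"reward": 1.0})
print(a.metrics, b.metrics)  # appending to one instance leaves the other untouched
```

Using a factory rather than a shared `[]` default keeps the per-task metric and payload lists independent across wrappers.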
class trinity.explorer.scheduler.CompletedTaskResult(batch_id: int | str, task_id: int | str, status: trinity.explorer.workflow_runner.Status, experience_payloads: typing.List[bytes] = <factory>)

Bases: object

A completed task result stored by batch and task id.

batch_id: int | str
task_id: int | str
status: Status
experience_payloads: List[bytes]
__init__(batch_id: int | str, task_id: int | str, status: trinity.explorer.workflow_runner.Status, experience_payloads: typing.List[bytes] = <factory>) -> None
class trinity.explorer.scheduler.RunningTaskState(task: TaskWrapper, runner_id: int, restart_runner_on_cancel: bool = True)

Bases: object

Per-future execution state tracked while one task is running.

task: TaskWrapper
runner_id: int
restart_runner_on_cancel: bool = True
__init__(task: TaskWrapper, runner_id: int, restart_runner_on_cancel: bool = True) -> None
trinity.explorer.scheduler.bootstrap_metric(data: list[Any], subset_size: int, reduce_fns: list[Callable[[ndarray], float]], n_bootstrap: int = 1000, seed: int = 42) -> list[tuple[float, float]]

Performs bootstrap resampling to estimate statistics of metrics.

This function uses bootstrap resampling to estimate the mean and standard deviation of metrics computed by the provided reduction functions on random subsets of the data.

Parameters:
  • data -- List of data points to bootstrap from.

  • subset_size -- Size of each bootstrap sample.

  • reduce_fns -- List of functions that compute a metric from a subset of data.

  • n_bootstrap -- Number of bootstrap iterations. Defaults to 1000.

  • seed -- Random seed for reproducibility. Defaults to 42.

Returns:

A list of tuples, where each tuple contains (mean, std) for the metric computed by the corresponding reduction function in reduce_fns.

Example

>>> import numpy as np
>>> from trinity.explorer.scheduler import bootstrap_metric
>>> data = [1, 2, 3, 4, 5]
>>> reduce_fns = [np.mean, np.max]
>>> bootstrap_metric(data, 3, reduce_fns)
[(3.0, 0.5), (4.5, 0.3)]  # Example values
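To make the resampling procedure concrete, here is a standard-library sketch of the behavior described above. It is not the module's implementation: the name `bootstrap_metric_sketch` is hypothetical, the reducers take plain lists instead of NumPy arrays, and drawing each subset with replacement and using the population standard deviation are both assumptions.

```python
import random
import statistics
from typing import Any, Callable, List, Sequence, Tuple


def bootstrap_metric_sketch(
    data: Sequence[Any],
    subset_size: int,
    reduce_fns: List[Callable[[List[Any]], float]],
    n_bootstrap: int = 1000,
    seed: int = 42,
) -> List[Tuple[float, float]]:
    """Estimate (mean, std) of each reducer over random subsets of data."""
    rng = random.Random(seed)  # local generator keeps runs reproducible
    per_fn_values: List[List[float]] = [[] for _ in reduce_fns]
    for _ in range(n_bootstrap):
        # Assumption: each subset is drawn with replacement.
        sample = rng.choices(data, k=subset_size)
        for values, fn in zip(per_fn_values, reduce_fns):
            values.append(float(fn(sample)))
    # Mean and population standard deviation of each metric across resamples.
    return [(statistics.fmean(v), statistics.pstdev(v)) for v in per_fn_values]


stats = bootstrap_metric_sketch([1, 2, 3, 4, 5], 3, [statistics.fmean, max])
for name, (mean, std) in zip(["mean", "max"], stats):
    print(f"{name}: mean={mean:.3f}, std={std:.3f}")
```

Because the generator is seeded, repeated calls with the same arguments return identical estimates.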
trinity.explorer.scheduler.calculate_task_level_metrics(metrics: List[Dict], is_eval: bool) -> Dict[str, float]

Calculate task-level metrics (mean) from multiple runs of the same task.

Parameters:
  • metrics (List[Dict]) -- A list of metric dictionaries from multiple runs of the same task.

  • is_eval (bool) -- Whether this is an evaluation task.

Returns:

A dictionary of aggregated metrics, where each metric is averaged over all runs.

Return type:

Dict[str, float]
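The averaging step can be illustrated with a small self-contained sketch. The function name `mean_over_runs` and the metric names are hypothetical, and the sketch deliberately ignores `is_eval`, since the description above does not say how evaluation tasks change the aggregation.

```python
from collections import defaultdict
from typing import Dict, List


def mean_over_runs(metrics: List[Dict[str, float]]) -> Dict[str, float]:
    """Average every metric name over the runs that report it."""
    grouped: Dict[str, List[float]] = defaultdict(list)
    for run_metrics in metrics:
        for name, value in run_metrics.items():
            grouped[name].append(float(value))
    # Assumption: a metric missing from some runs is averaged
    # only over the runs that actually contain it.
    return {name: sum(values) / len(values) for name, values in grouped.items()}


runs = [
    {"reward": 1.0, "response_length": 120.0},
    {"reward": 0.0, "response_length": 80.0},
]
print(mean_over_runs(runs))
```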

class trinity.explorer.scheduler.RunnerWrapper(runner_id: int, rollout_model: InferenceModel, auxiliary_models: List[InferenceModel], config: Config)

Bases: object

A wrapper for a WorkflowRunner.

__init__(runner_id: int, rollout_model: InferenceModel, auxiliary_models: List[InferenceModel], config: Config)
async prepare()
async update_state() -> None

Get the runner state.

async run_with_retry(task: TaskWrapper, repeat_times: int, run_id_base: int, timeout: float, collect_partial_runs: bool) -> Tuple[Status, bytes, int, float]

Parameters:
  • task (TaskWrapper) -- The task to run.

  • repeat_times (int) -- The number of times to repeat the task.

  • run_id_base (int) -- The base run id for the runs of this task.

  • timeout (float) -- The timeout for each task run.

Returns:

A tuple of four values, in order: the return status of the task (Status), the experiences generated by the task as a serialized payload (bytes), the runner_id of the current runner (int), and the time taken to run the task (float).

Return type:

Tuple[Status, bytes, int, float]

async restart_runner()
trinity.explorer.scheduler.sort_batch_id(batch_id: int | str)

Return the sort priority of a batch_id.

class trinity.explorer.scheduler.Scheduler(config: Config, rollout_model: List[InferenceModel], auxiliary_models: List[List[InferenceModel]] | None = None)

Bases: object

Scheduler for rollout tasks.

Supports scheduling tasks to multiple runners, retrying failed tasks, and collecting results at different levels.

__init__(config: Config, rollout_model: List[InferenceModel], auxiliary_models: List[List[InferenceModel]] | None = None)
task_done_callback(async_task: Task)
discard_completed_results(batch_id: int | str) -> None

Drop cached completed results for one batch.

async start() -> None
async stop() -> None
schedule(tasks: List[Task], batch_id: int | str) -> None

Schedule the provided tasks.

Parameters:
  • tasks (List[Task]) -- The tasks to schedule.

  • batch_id (Union[int, str]) -- The id of the provided tasks. In most cases, it should be the current step number for training tasks and "<current_step_num>/<eval_taskset_name>" for eval tasks.
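The batch_id convention described above can be captured in a few helper functions. These helpers are hypothetical and not part of the module; they only encode the documented "<current_step_num>/<eval_taskset_name>" pattern, and the taskset name used below is an arbitrary placeholder.

```python
from typing import Union


def train_batch_id(step: int) -> int:
    """Training batches are identified by the current step number."""
    return step


def eval_batch_id(step: int, taskset_name: str) -> str:
    """Eval batches combine the step number and the eval taskset name."""
    return f"{step}/{taskset_name}"


def step_of(batch_id: Union[int, str]) -> int:
    """Recover the step number from either kind of batch id."""
    return int(str(batch_id).split("/", 1)[0])


print(train_batch_id(12))
print(eval_batch_id(12, "my_eval_set"))
print(step_of(eval_batch_id(12, "my_eval_set")))
```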

dynamic_timeout(timeout: float | None = None) -> float

Calculate dynamic timeout based on historical data.

async drain_batch_payload_results(batch_id: int | str) -> Tuple[List[Status], List[bytes]]

Drain cached completed results for one batch.

async get_payload_results(batch_id: int | str, min_num: int | None = None, timeout: float | None = None, clear_timeout_tasks: bool = True, return_partial_tasks: bool = False) -> Tuple[List[Status], List[bytes]]

Wait for one batch and return task statuses plus serialized payload chunks.

async get_statuses(batch_id: int | str, min_num: int | None = None, timeout: float | None = None, clear_timeout_tasks: bool = True, return_partial_tasks: bool = False) -> List[Status]

Wait for one batch and return only task statuses without materializing experiences.

async abort_batch(batch_id: int | str, return_partial_tasks: bool = False, restart_runners: bool = True) -> None

Abort one batch and cleanup unfinished scheduler state.

has_step(batch_id: int | str) -> bool
async wait_all(timeout: float | None = None, clear_timeout_tasks: bool = True) -> None

Wait for all tasks to complete without popping results. If the timeout is reached, raise TimeoutError.

Parameters:
  • timeout (float) -- Timeout in seconds. TimeoutError is raised when no new task completes within the timeout.

  • clear_timeout_tasks (bool) -- Whether to clear timed-out tasks.
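The timeout described above counts from the most recent completion rather than from the start of the wait. A generic asyncio sketch of that idle-timeout pattern follows; it is an illustration of the documented semantics, not the scheduler's code, and the function name `wait_all_with_idle_timeout` is hypothetical.

```python
import asyncio
from typing import Iterable, Optional, Set


async def wait_all_with_idle_timeout(
    tasks: Iterable["asyncio.Task"], timeout: Optional[float] = None
) -> None:
    """Wait for every task; fail if no task finishes within `timeout` seconds."""
    pending: Set["asyncio.Task"] = set(tasks)
    while pending:
        # Each loop iteration restarts the clock, so the timeout only
        # fires when no progress at all is made during the window.
        done, pending = await asyncio.wait(
            pending, timeout=timeout, return_when=asyncio.FIRST_COMPLETED
        )
        if not done:
            raise TimeoutError(f"no task completed within {timeout} seconds")


async def demo() -> bool:
    quick = [asyncio.create_task(asyncio.sleep(0.01)) for _ in range(3)]
    await wait_all_with_idle_timeout(quick, timeout=1.0)  # finishes normally
    stuck = asyncio.create_task(asyncio.sleep(10))
    try:
        await wait_all_with_idle_timeout([stuck], timeout=0.05)
    except TimeoutError:
        stuck.cancel()  # the caller decides what to do with leftover tasks
        return True
    return False


timed_out = asyncio.run(demo())
print("idle timeout raised:", timed_out)
```

In this sketch the leftover task is cancelled by the caller, which loosely corresponds to the choice exposed by clear_timeout_tasks.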

get_key_state(key: str) -> Dict

Get the state stored under one key for every runner.

Parameters:

key (str) -- The key of the state to get.

Returns:

A dictionary of runner ids to their state for the given key.

Return type:

Dict
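The returned mapping can be pictured as a projection of each runner's state onto a single key. The sketch below assumes that every runner's state is a plain dictionary; the function name `project_key` and the state keys shown are hypothetical.

```python
from typing import Any, Dict


def project_key(all_state: Dict[int, Dict[str, Any]], key: str) -> Dict[int, Any]:
    """Map each runner id to the value stored under `key` (None if absent)."""
    return {runner_id: state.get(key) for runner_id, state in all_state.items()}


# Hypothetical runner states, keyed by runner id.
all_state = {
    0: {"status": "running", "finished_tasks": 4},
    1: {"status": "idle", "finished_tasks": 7},
}
print(project_key(all_state, "status"))
```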

get_runner_state(runner_id: int) -> Dict

Get the state of one runner.

Parameters:

runner_id (int) -- The id of the runner.

Returns:

The state of the runner.

Return type:

Dict

get_all_state() -> Dict

Get all runners' state.

Returns:

The state of all runners.

Return type:

Dict

print_all_state() -> None

Print all runners' state in a clear, aligned table format.
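One way to produce an aligned table from per-runner state is to pad every column to its widest cell. The sketch below assumes the state is a dictionary of dictionaries keyed by runner id; the function name `format_state_table` and the state keys are hypothetical, and the layout printed by the module itself may differ.

```python
from typing import Any, Dict, List


def format_state_table(all_state: Dict[int, Dict[str, Any]]) -> str:
    """Render runner states as a left-aligned, space-padded text table."""
    columns = sorted({key for state in all_state.values() for key in state})
    header: List[str] = ["runner_id"] + columns
    rows: List[List[str]] = [
        [str(runner_id)] + [str(state.get(col, "")) for col in columns]
        for runner_id, state in sorted(all_state.items())
    ]
    table = [header] + rows
    # Each column is as wide as its longest cell, header included.
    widths = [max(len(row[i]) for row in table) for i in range(len(header))]
    return "\n".join(
        "  ".join(cell.ljust(width) for cell, width in zip(row, widths))
        for row in table
    )


# Hypothetical runner states, keyed by runner id.
state_table = format_state_table(
    {
        0: {"status": "running", "finished_tasks": 4},
        1: {"status": "idle", "finished_tasks": 12},
    }
)
print(state_table)
```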