trinity.explorer.scheduler module
Scheduler for rollout tasks.
- class trinity.explorer.scheduler.TaskWrapper(task: Task, batch_id: int | str, sub_task_num: int = 1, results: List[Tuple[Status, List[Experience]]] = <factory>) [source]
Bases: object
A wrapper for a task. Each task can run multiple times (repeat_times) on the same or different runners.
- batch_id: int | str
- sub_task_num: int = 1
- results: List[Tuple[Status, List[Experience]]]
- __init__(task: Task, batch_id: int | str, sub_task_num: int = 1, results: List[Tuple[Status, List[Experience]]] = <factory>) → None
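The `<factory>` default on `results` suggests a dataclass field built with `default_factory`, so that each wrapper gets its own list. The sketch below illustrates that pattern only; `TaskWrapperSketch` is a hypothetical stand-in, the `Any` types replace the real `Task`, `Status`, and `Experience` classes, and the use of `list` as the factory is an assumption.

```python
from dataclasses import dataclass, field
from typing import Any, List, Tuple, Union


@dataclass
class TaskWrapperSketch:
    """Hypothetical stand-in mirroring the documented fields only."""

    task: Any  # the real field is a trinity Task
    batch_id: Union[int, str]
    sub_task_num: int = 1
    # Assumption: the "<factory>" default is an empty list per instance.
    results: List[Tuple[Any, List[Any]]] = field(default_factory=list)


a = TaskWrapperSketch(task="t1", batch_id=0)
b = TaskWrapperSketch(task="t2", batch_id="0/eval_set")

# Appending to one wrapper's results leaves the other untouched.
a.results.append(("ok", []))
```

A shared mutable default (`results=[]`) would leak results between wrappers, which is the usual reason for a factory default.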
- trinity.explorer.scheduler.bootstrap_metric(data: list[Any], subset_size: int, reduce_fns: list[Callable[[ndarray], float]], n_bootstrap: int = 1000, seed: int = 42) → list[tuple[float, float]] [source]
Performs bootstrap resampling to estimate statistics of metrics.
This function uses bootstrap resampling to estimate the mean and standard deviation of metrics computed by the provided reduction functions on random subsets of the data.
- Parameters:
data -- List of data points to bootstrap from.
subset_size -- Size of each bootstrap sample.
reduce_fns -- List of functions that compute a metric from a subset of data.
n_bootstrap -- Number of bootstrap iterations. Defaults to 1000.
seed -- Random seed for reproducibility. Defaults to 42.
- Returns:
A list of tuples, where each tuple contains (mean, std) for a metric corresponding to each reduction function in reduce_fns.
Example
>>> import numpy as np
>>> data = [1, 2, 3, 4, 5]
>>> reduce_fns = [np.mean, np.max]
>>> bootstrap_metric(data, 3, reduce_fns)
[(3.0, 0.5), (4.5, 0.3)]  # Example values
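To make the resampling procedure concrete, here is a minimal sketch of the idea using only the standard library. It is not the library's implementation: `bootstrap_sketch` is a hypothetical name, sampling with replacement is an assumption, and the reduction functions receive plain lists rather than an `ndarray`.

```python
import random
import statistics
from typing import Callable, List, Sequence, Tuple


def bootstrap_sketch(
    data: Sequence[float],
    subset_size: int,
    reduce_fns: List[Callable[[List[float]], float]],
    n_bootstrap: int = 1000,
    seed: int = 42,
) -> List[Tuple[float, float]]:
    """Return (mean, std) of each metric over n_bootstrap random subsets."""
    rng = random.Random(seed)  # fixed seed keeps the estimate reproducible
    per_fn: List[List[float]] = [[] for _ in reduce_fns]
    for _ in range(n_bootstrap):
        # Assumption: subsets are drawn with replacement.
        subset = rng.choices(data, k=subset_size)
        for values, fn in zip(per_fn, reduce_fns):
            values.append(float(fn(subset)))
    # Population standard deviation of each metric across iterations.
    return [(statistics.fmean(v), statistics.pstdev(v)) for v in per_fn]


stats = bootstrap_sketch([1, 2, 3, 4, 5], 3, [statistics.fmean, max])
```

Each tuple in `stats` lines up with the reduction function at the same index, so `stats[1]` describes how the subset maximum varies across resamples.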
- trinity.explorer.scheduler.calculate_task_level_metrics(metrics: List[Dict], is_eval: bool) → Dict[str, float] [source]
Calculate task level metrics (mean) from multiple runs of the same task.
- Parameters:
metrics (List[Dict]) -- A list of metric dictionaries from multiple runs of the same task.
is_eval (bool) -- Whether this is an evaluation task.
- Returns:
A dictionary of aggregated metrics, where each metric is averaged over all runs.
- Return type:
Dict[str, float]
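The averaging step can be pictured as follows. This is a sketch under stated assumptions, not the library's code: `average_metrics` is a hypothetical helper, and it leaves out `is_eval` because this page does not describe how that flag changes the aggregation.

```python
from collections import defaultdict
from typing import Dict, List


def average_metrics(metrics: List[Dict[str, float]]) -> Dict[str, float]:
    """Average every metric key over the runs that reported it."""
    buckets: Dict[str, List[float]] = defaultdict(list)
    for run in metrics:
        for name, value in run.items():
            buckets[name].append(float(value))
    return {name: sum(values) / len(values) for name, values in buckets.items()}


# Two runs of the same task, each reporting the same metric names.
runs = [{"reward": 1.0, "length": 10}, {"reward": 0.0, "length": 14}]
averaged = average_metrics(runs)
```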
- class trinity.explorer.scheduler.RunnerWrapper(runner_id: int, rollout_model: InferenceModel, auxiliary_models: List[InferenceModel], config: Config) [source]
Bases: object
A wrapper for a WorkflowRunner.
- __init__(runner_id: int, rollout_model: InferenceModel, auxiliary_models: List[InferenceModel], config: Config) [source]
- async run_with_retry(task: TaskWrapper, repeat_times: int, run_id_base: int, timeout: float) → Tuple[Status, List, int, float] [source]
- Parameters:
task (TaskWrapper) -- The task to run.
repeat_times (int) -- The number of times to repeat the task.
run_id_base (int) -- The base run ID for the runs of this task.
timeout (float) -- The timeout for each task run.
- Returns:
A tuple of four items. Status: The return status of the task. List: The experiences generated by the task. int: The runner_id of the current runner. float: The time taken to run the task.
- Return type:
Tuple[Status, List, int, float]
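The general shape of a retry loop with a per-attempt timeout is sketched below with plain `asyncio`. Everything here is illustrative: `run_with_retry_sketch`, `max_retries`, and the boolean status are hypothetical, and this page does not say how many times or under which conditions the real method retries.

```python
import asyncio
import time
from typing import Any, Awaitable, Callable, Dict, Tuple


async def run_with_retry_sketch(
    job: Callable[[], Awaitable[Any]],
    runner_id: int,
    timeout: float,
    max_retries: int = 2,  # hypothetical knob, not a documented parameter
) -> Tuple[bool, Any, int, float]:
    """Return (ok, result, runner_id, elapsed_seconds)."""
    start = time.monotonic()
    for _ in range(max_retries + 1):
        try:
            # Each attempt gets its own timeout budget.
            result = await asyncio.wait_for(job(), timeout=timeout)
            return True, result, runner_id, time.monotonic() - start
        except Exception:
            # Timeouts and ordinary errors both trigger another attempt.
            continue
    return False, None, runner_id, time.monotonic() - start


calls: Dict[str, int] = {"n": 0}


async def flaky_job() -> list:
    """Fail on the first call, succeed afterwards."""
    calls["n"] += 1
    if calls["n"] == 1:
        raise RuntimeError("transient failure")
    return ["experience"]


ok, result, rid, elapsed = asyncio.run(
    run_with_retry_sketch(flaky_job, runner_id=0, timeout=1.0)
)
```

Returning the runner id and elapsed time alongside the result mirrors the documented four-item tuple, which lets a caller attribute timing to a specific runner.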
- class trinity.explorer.scheduler.Scheduler(config: Config, rollout_model: List[InferenceModel], auxiliary_models: List[List[InferenceModel]] | None = None) [source]
Bases: object
Scheduler for rollout tasks.
Supports scheduling tasks to multiple runners, retrying failed tasks, and collecting results at different levels.
- __init__(config: Config, rollout_model: List[InferenceModel], auxiliary_models: List[List[InferenceModel]] | None = None) [source]
- schedule(tasks: List[Task], batch_id: int | str) → None [source]
Schedule the provided tasks.
- Parameters:
tasks (List[Task]) -- The tasks to schedule.
batch_id (Union[int, str]) -- The id of the provided tasks. In most cases, it should be the current step number for training tasks and "<current_step_num>/<eval_taskset_name>" for eval tasks.
- dynamic_timeout(timeout: float | None = None) → float [source]
Calculate dynamic timeout based on historical data.
- async get_results(batch_id: int | str, min_num: int | None = None, timeout: float | None = None, clear_timeout_tasks: bool = True) → Tuple[List[Status], List[Experience]] [source]
Get the results of the tasks in the specified batch_id.
- Parameters:
batch_id (Union[int, str]) -- Only wait for tasks at this batch.
min_num (int) -- The minimum number of tasks to wait for. If None, wait for all tasks at batch_id.
timeout (float) -- The timeout for waiting for tasks to finish. If None, the default timeout is used.
clear_timeout_tasks (bool) -- Whether to clear timeout tasks.
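The schedule-then-collect flow, including the two batch_id conventions, might look like the sketch below. `FakeScheduler` is a hypothetical stand-in that copies only the documented method signatures so the example can run without the library; a real `Scheduler` needs a `Config` and `InferenceModel` instances, and may require setup steps that this page does not list.

```python
import asyncio
from typing import Any, Dict, List, Optional, Tuple, Union

BatchId = Union[int, str]


class FakeScheduler:
    """Hypothetical stand-in; tasks 'finish' instantly with a dummy result."""

    def __init__(self) -> None:
        self._batches: Dict[BatchId, List[Any]] = {}

    def schedule(self, tasks: List[Any], batch_id: BatchId) -> None:
        self._batches.setdefault(batch_id, []).extend(tasks)

    async def get_results(
        self,
        batch_id: BatchId,
        min_num: Optional[int] = None,
        timeout: Optional[float] = None,
        clear_timeout_tasks: bool = True,
    ) -> Tuple[List[str], List[str]]:
        queued = self._batches.get(batch_id, [])
        # min_num=None means "everything in this batch".
        take = len(queued) if min_num is None else min(min_num, len(queued))
        finished, self._batches[batch_id] = queued[:take], queued[take:]
        return ["ok"] * len(finished), [f"experience-for-{t}" for t in finished]


async def run_step(scheduler: Any, train: List[Any], evals: List[Any],
                   step: int, eval_name: str):
    # Training tasks use the step number as batch_id.
    scheduler.schedule(train, batch_id=step)
    # Eval tasks use "<current_step_num>/<eval_taskset_name>".
    eval_batch_id = f"{step}/{eval_name}"
    scheduler.schedule(evals, batch_id=eval_batch_id)
    train_out = await scheduler.get_results(batch_id=step)
    eval_out = await scheduler.get_results(batch_id=eval_batch_id)
    return train_out, eval_out


train_out, eval_out = asyncio.run(
    run_step(FakeScheduler(), ["t1", "t2"], ["e1"], step=3, eval_name="math_eval")
)
```

Keeping training and eval tasks under distinct batch ids lets each `get_results` call wait on only its own batch.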
- async wait_all(timeout: float | None = None, clear_timeout_tasks: bool = True) → None [source]
Wait for all tasks to complete without popping results. Raises TimeoutError if the timeout is reached.
- Parameters:
timeout (float) -- Timeout in seconds. TimeoutError is raised when no new task completes within the timeout.
clear_timeout_tasks (bool) -- Whether to clear timeout tasks.
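The timeout described here is a no-progress timeout: the clock restarts whenever a task completes. A minimal sketch of that behaviour with plain `asyncio` follows; `wait_all_sketch` is a hypothetical name, and cancelling the leftover tasks on timeout is an assumption loosely modelled on `clear_timeout_tasks`.

```python
import asyncio
from typing import Awaitable, Iterable


async def wait_all_sketch(jobs: Iterable[Awaitable[object]], timeout: float) -> None:
    """Wait for all jobs; fail only if `timeout` passes with no completion."""
    pending = {asyncio.ensure_future(job) for job in jobs}
    while pending:
        # The timeout applies to each wait, so progress resets the clock.
        done, pending = await asyncio.wait(
            pending, timeout=timeout, return_when=asyncio.FIRST_COMPLETED
        )
        if not done:
            # Assumption: stalled tasks are cancelled before raising.
            for task in pending:
                task.cancel()
            await asyncio.gather(*pending, return_exceptions=True)
            raise TimeoutError(f"no task completed within {timeout} seconds")


async def demo() -> str:
    # Total runtime (0.3 s) exceeds the timeout, but each gap is only 0.1 s.
    await wait_all_sketch([asyncio.sleep(d) for d in (0.1, 0.2, 0.3)], timeout=0.18)
    try:
        # A single stalled task makes no progress, so this one times out.
        await wait_all_sketch([asyncio.sleep(10)], timeout=0.05)
    except TimeoutError:
        return "stalled batch timed out"
    return "unexpected"


outcome = asyncio.run(demo())
```

The first call succeeds even though the whole batch takes longer than the timeout, which is the practical difference from a single overall deadline.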
- get_key_state(key: str) → Dict [source]
Get the scheduler state.
- Parameters:
key (str) -- The key of the state to get.
- Returns:
A dictionary mapping runner ids to their state for the given key.
- Return type:
Dict