# -*- coding: utf-8 -*-
"""The explorer module"""
from __future__ import annotations
import asyncio
import math
import os
import time
import traceback
from collections import deque
from typing import List, Optional
import ray
import torch
from trinity.buffer.buffer import get_buffer_reader
from trinity.buffer.task_scheduler import get_taskset_scheduler
from trinity.common.config import Config
from trinity.common.constants import (
ROLLOUT_WEIGHT_SYNC_GROUP_NAME,
RunningStatus,
SyncMethod,
SyncStyle,
)
from trinity.common.models import create_explorer_models
from trinity.explorer.rollout_coordinator import RolloutCoordinator
from trinity.manager.state_manager import StateManager
from trinity.manager.synchronizer import Synchronizer
from trinity.utils.annotations import Experimental
from trinity.utils.log import get_logger
from trinity.utils.monitor import MONITOR
from trinity.utils.plugin_loader import load_plugins
from trinity.utils.timer import Timer
[文档]
class Explorer:
"""Responsible for exploring the taskset."""
[文档]
def __init__(self, config: Config):
    """Build the explorer from ``config``.

    Loads the persisted explorer state, creates the rollout and auxiliary
    models, the taskset scheduler (skipped in ``bench`` / ``serve`` mode)
    and the monitor, then caches the synchronizer settings.

    Args:
        config: The global configuration object.
    """
    explorer_cfg = config.explorer
    self.logger = get_logger(explorer_cfg.name, in_ray_actor=True)
    load_plugins()

    # Resume the step counter from the last saved explorer state (0 if none).
    self.state = StateManager(
        path=config.checkpoint_job_dir,
        explorer_name=explorer_cfg.name,
        config=config,
    )
    explorer_state = self.state.load_explorer()
    self.explore_step_num = explorer_state.get("latest_iteration", 0)
    self.last_monitored_step = self.explore_step_num

    self.synchronizer = Synchronizer.get_actor(config)
    self.config = config
    self.model_type = explorer_cfg.rollout_model.engine_type
    self.models, self.auxiliary_models = create_explorer_models(config)

    # No training taskset is read in bench / serve mode.
    if self.config.mode in {"bench", "serve"}:
        self.taskset = None
    else:
        self.taskset = get_taskset_scheduler(explorer_state=explorer_state, config=config)

    monitor_cls = MONITOR.get(self.config.monitor.monitor_type)
    self.monitor = monitor_cls(
        project=self.config.project,
        group=self.config.group,
        name=self.config.name,
        role=self.config.explorer.name,
        config=config,
    )
    self.detailed_stats = config.monitor.detailed_stats

    # With over rollout, a step only waits for a fraction of the batch.
    over_rollout_ratio = config.explorer.over_rollout.ratio
    self.min_wait_num = None
    if over_rollout_ratio > 0.0:
        self.min_wait_num = math.ceil(config.buffer.batch_size * (1 - over_rollout_ratio))
        self.logger.info(
            f"Over rollout is enabled. Explorer will only wait for {self.min_wait_num} tasks in each step."
        )

    self.rollout_coordinator = None
    self.use_nccl_sync = self.config.synchronizer.sync_method == SyncMethod.NCCL
    # Queue of (step, taskset name) pairs whose eval results are not collected yet.
    self.pending_eval_tasks = deque()

    # For checkpoint weights update:
    # use the explorer to periodically load the latest model weights and
    # broadcast them to all rollout models.
    self.enable_lora = self.config.explorer.rollout_model.enable_lora
    self.model_version = -1

    sync_cfg = config.synchronizer
    self.sync_offset = sync_cfg.sync_offset
    self.sync_interval = sync_cfg.sync_interval
    self.sync_method = sync_cfg.sync_method
    self.sync_style = sync_cfg.sync_style

    self.eval_start_time = None
    self.explore_start_time = None
    self.logger.info("Finished initializing Explorer.")
[文档]
async def setup_weight_sync_group(
    self, master_address: str, master_port: int, state_dict_meta: List = None
):
    """Initialize one weight-sync process group shared by all rollout models.

    The world size is ``len(self.models) * tensor_parallel_size``, plus one
    extra rank when NCCL synchronization is used (presumably the trainer's
    rank — confirm against the trainer side).

    Args:
        master_address: Address of the process-group master.
        master_port: Port of the process-group master.
        state_dict_meta: Optional metadata forwarded unchanged to every model.
    """
    tp_size = self.config.explorer.rollout_model.tensor_parallel_size
    rank_base = 1 if self.use_nccl_sync else 0
    world_size = len(self.models) * tp_size + rank_base
    self.logger.info(
        f"Initialize process group for weight synchronization, "
        f"master_address={master_address}, master_port={master_port}, "
        f"world_size={world_size}, rank_offset={rank_base}"
    )
    # TODO: save state_dict in models
    init_refs = []
    for model_index, model in enumerate(self.models):
        # Each model occupies `tp_size` consecutive ranks after the base offset.
        init_refs.append(
            model.init_process_group.remote(
                master_address=master_address,
                master_port=master_port,
                rank_offset=model_index * tp_size + rank_base,
                world_size=world_size,
                group_name=ROLLOUT_WEIGHT_SYNC_GROUP_NAME,
                explorer_name=self.config.explorer.name,
                timeout=self.config.synchronizer.sync_timeout,
                state_dict_meta=state_dict_meta,
            )
        )
    await asyncio.gather(*init_refs)
[文档]
async def setup_model_level_weight_sync_group(self):
    """Setup process group for each model, only used in serve mode.

    Every model asks for its own master address/port and gets a private
    group whose world size equals the tensor parallel size.
    """
    world_size = self.config.explorer.rollout_model.tensor_parallel_size
    init_refs = []
    for model in self.models:
        address_info = await model.get_available_address.remote()
        master_address, master_port = address_info
        self.logger.info(
            f"Initialize process group for model weight synchronization, "
            f"master_address={master_address}, master_port={master_port}, "
            f"world_size={world_size}"
        )
        init_ref = model.init_process_group.remote(
            master_address=master_address,
            master_port=master_port,
            rank_offset=0,
            world_size=world_size,
            group_name=ROLLOUT_WEIGHT_SYNC_GROUP_NAME,
            explorer_name=self.config.explorer.name,
            timeout=self.config.synchronizer.sync_timeout,
        )
        init_refs.append(init_ref)
    # Wait for all groups at once after every init call has been issued.
    await asyncio.gather(*init_refs)
async def _checkpoint_weights_update(self, step_num: Optional[int] = None) -> int:
    """Load checkpoint weights through the synchronizer and sync all models.

    Args:
        step_num: Checkpoint step to request; ``None`` is passed through to
            the synchronizer unchanged.

    Returns:
        The step number reported back by the synchronizer.
    """
    self.logger.info(f"Start to update model weights from checkpoint at step {step_num}.")
    step_num = await self.synchronizer.set_model_state_dict_with_step_num.remote(step_num)
    sync_refs = [model.sync_model.remote(step_num) for model in self.models]
    await asyncio.gather(*sync_refs)
    self.logger.info(f"Model weights updated to checkpoint at step {step_num}.")
    return step_num  # type: ignore
async def _pull_latest_weights(self):
    """Wait for a newer model version from the synchronizer and apply it.

    The models are synced unless this is the very first pull
    (``model_version == -1``) and the reported version is not above 0.
    In both cases ``model_version`` is advanced to the new version.
    """
    self.logger.info("Start to pull latest model weights.")
    new_version = await self.synchronizer.wait_new_model_state_dict.remote(
        current_version=self.model_version,
    )
    has_newer_version = new_version > self.model_version
    if not has_newer_version:
        self.logger.warning(
            f"No new model weights found, current version: {self.model_version}"
        )
        return

    already_synced_before = self.model_version != -1
    if already_synced_before or new_version > 0:
        self.logger.info(f"New model weights version: {new_version}")
        sync_refs = [model.sync_model.remote(new_version) for model in self.models]
        await asyncio.gather(*sync_refs)
    self.model_version = new_version
async def _nccl_weights_update(self):
    """Sync weights over NCCL once the synchronizer reports a version.

    Does nothing (besides logging) when the synchronizer returns ``None``.
    """
    new_version = await self.synchronizer.ready_to_nccl_sync.remote(
        "explorer", self.model_version
    )
    if new_version is not None:
        self.model_version = new_version
        sync_refs = [model.sync_model.remote(self.model_version) for model in self.models]
        await asyncio.gather(*sync_refs)
        return
    self.logger.info("Trainer is not ready to sync weight. Skipping sync weight.")
[文档]
async def prepare(self) -> None:
    """Preparation before running.

    Prepares all rollout and auxiliary models, sets up the weight-sync
    process group(s) when NCCL sync is not used, starts the rollout
    coordinator, optionally runs the startup evaluation and finally marks
    the explorer as running. On any failure the explorer is shut down and
    the exception is re-raised.
    """
    try:
        # make sure all rollout models are ready
        prepare_refs = [model.prepare.remote() for model in self.models]
        prepare_refs += [
            aux_model.prepare.remote()
            for aux_group in self.auxiliary_models
            for aux_model in aux_group
        ]
        await asyncio.gather(*prepare_refs)
        self.logger.info("All models are ready.")

        needs_sync_group = not self.use_nccl_sync and self.model_type not in {
            "tinker",
            "external",
        }
        if needs_sync_group and self.config.mode == "serve":
            # In serving mode, each engine will setup its own process group
            await self.setup_model_level_weight_sync_group()
        elif needs_sync_group:
            first_model = self.models[0]
            master_address, master_port = await first_model.get_available_address.remote()
            await self.setup_weight_sync_group(master_address, master_port)

        self.rollout_coordinator = RolloutCoordinator.get_actor(
            self.config,
            self.models,
            self.auxiliary_models,
        )
        await self.rollout_coordinator.prepare.remote()
        self.logger.info("Rollout coordinator is ready.")

        # Startup evaluation only happens on a fresh run (step 0).
        if self.config.explorer.eval_on_startup and self.explore_step_num == 0:
            await self.eval()
        await self.synchronizer.set_explorer_status.remote(RunningStatus.RUNNING)
    except Exception as e:
        self.logger.error(f"Error during explorer preparation: {traceback.format_exc()}")
        await self.shutdown()
        raise e
[文档]
async def get_weight(self, name: str) -> torch.Tensor:
    """Get the weight of the loaded model (For checkpoint weights update).

    Args:
        name: Key of the weight inside ``self.state_dict``.

    Returns:
        The entry stored under ``name``.

    Raises:
        AttributeError: If ``self.state_dict`` has not been set.
        KeyError: If ``name`` is not present in ``self.state_dict``.
    """
    # NOTE(review): `state_dict` is not assigned in the visible `__init__`;
    # confirm where it is expected to be populated.
    state_dict = getattr(self, "state_dict", None)
    if state_dict is None:
        raise AttributeError(
            f"Explorer.state_dict is not set; cannot look up weight {name!r}."
        )
    return state_dict[name]
[文档]
async def explore(self) -> str:
    """Run explore steps until the taskset is exhausted or an error occurs.

    The timeline of the exploration process:

             | <--------------------------------- one period -------------------------------------> |
    explorer | <---------------- step_1 --------------> |                                           |
             |   | <---------------- step_2 --------------> |                                       |
             |      ...                                                                             |
             |            | <---------------- step_n ---------------> |                             |
             |                  | <---------------------- eval --------------------> | <-- sync --> |
             |--------------------------------------------------------------------------------------|
    trainer  | <-- idle --> | <-- step_1 --> | <-- step_2 --> | ... | <-- step_n --> | <-- sync --> |

    Returns:
        The configured explorer name.
    """
    while True:
        try:
            self.logger.info(f"Explore step {self.explore_step_num + 1} started.")
            should_continue = await self.explore_step()
            if should_continue:
                if self.need_eval():
                    await self.eval()
                if await self.need_sync():
                    await self.sync_weight()
                continue
            # TODO: support eval on last checkpoint
        except Exception:
            self.logger.error(f"Error in Explorer: {traceback.format_exc()}")
        # Reached when the step reported "stop" or an error was logged.
        break
    self.logger.info(
        f"--------------------\n> Explorer ({self.config.explorer.name}) finished.\n--------------------"
    )
    return self.config.explorer.name
[文档]
async def explore_step(self) -> bool:
    """Read one batch of tasks and submit it as the next training batch.

    Returns:
        ``True`` when a batch was submitted; ``False`` when the taskset is
        exhausted, in which case pending steps are finished, a checkpoint is
        saved, the status is set to stopped and the explorer is shut down.
    """
    # Start the interval timer on the first step after a sync.
    if self.explore_start_time is None:
        self.explore_start_time = time.time()

    try:
        task_batch = await self.taskset.read_async()
    except StopAsyncIteration:
        self.logger.warning("No more tasks to explore. Stop exploring.")
        await self.finish_current_steps()
        await self.save_checkpoint()
        await self.synchronizer.set_explorer_status.remote(
            RunningStatus.STOPPED,
            old_status=RunningStatus.RUNNING,
        )
        await self.shutdown()
        return False
    else:
        self.explore_step_num += 1
        assert (
            self.rollout_coordinator is not None
        ), "Rollout coordinator must be prepared first."
        await self.rollout_coordinator.submit_batch.remote(
            batch_id=self.explore_step_num,
            tasks=task_batch,
            batch_type="train",
            min_wait_num=self.min_wait_num,
        )
        return True
[文档]
async def finish_current_steps(self) -> None:
    """Finish every step submitted since the last monitored step.

    Does nothing when the rollout coordinator has not been created.
    """
    if self.rollout_coordinator is None:
        return
    first_pending_step = self.last_monitored_step + 1
    await self._finish_steps(first_pending_step, self.explore_step_num, self.model_version)
    self.last_monitored_step = self.explore_step_num
[文档]
async def need_sync(self) -> bool:
    """Decide whether weights should be synchronized after the current step.

    Returns ``False`` until the step number exceeds ``sync_offset`` and on
    steps that are not a multiple of ``sync_interval`` past the offset.
    On a sync boundary the pending steps are finished first; the trainer is
    then asked whether it requires a sync when the style is trainer-driven
    with NCCL, otherwise a sync is always required.
    """
    if self.explore_step_num <= self.sync_offset:
        return False
    on_sync_boundary = (self.explore_step_num - self.sync_offset) % self.sync_interval == 0
    if not on_sync_boundary:
        return False

    await self.finish_current_steps()
    trainer_decides = (
        self.sync_style == SyncStyle.TRAINER_DRIVEN and self.sync_method == SyncMethod.NCCL
    )
    if trainer_decides:
        return await self.synchronizer.trainer_requires_sync.remote()
    return True
[文档]
def need_eval(self) -> bool:
    """Return whether an evaluation is due at the current explore step.

    Returns:
        ``True`` when ``explore_step_num`` is a multiple of the configured
        ``eval_interval``; ``False`` otherwise, including when the interval
        is 0 or unset (treated as "evaluation disabled" instead of raising
        ``ZeroDivisionError``).
    """
    eval_interval = self.config.explorer.eval_interval
    if not eval_interval:
        return False
    return self.explore_step_num % eval_interval == 0
[文档]
async def eval(self):
    """Evaluation on all evaluation data samples.

    Reads every configured eval taskset to exhaustion and submits it to the
    rollout coordinator as one ``eval`` batch named ``"<step>/<taskset>"``.
    Results are collected later; each submitted batch is recorded in
    ``pending_eval_tasks``.
    """
    self.eval_start_time = time.time()
    explorer_input = self.config.buffer.explorer_input
    if len(explorer_input.eval_tasksets) == 0:
        self.logger.warning("No evaluation data samples. Skip evaluation.")
        return

    self.logger.info(f"Evaluation at step {self.explore_step_num} started.")
    if explorer_input.default_eval_workflow_type:
        self.logger.info(
            f"Use '{self.config.buffer.explorer_input.default_eval_workflow_type}' for evaluation."
        )

    for taskset_cfg in explorer_input.eval_tasksets:
        self.logger.info(
            f"Evaluation on {taskset_cfg.name} at step {self.explore_step_num} started."
        )
        reader = get_buffer_reader(taskset_cfg)
        eval_batch_id = f"{self.explore_step_num}/{taskset_cfg.name}"
        self.pending_eval_tasks.append((self.explore_step_num, taskset_cfg.name))

        # Drain the reader; it signals exhaustion with StopAsyncIteration.
        collected_tasks = []
        try:
            while True:
                collected_tasks.extend(await reader.read_async())
        except StopAsyncIteration:
            pass

        assert (
            self.rollout_coordinator is not None
        ), "Rollout coordinator must be prepared first."
        await self.rollout_coordinator.submit_batch.remote(
            batch_id=eval_batch_id,
            tasks=collected_tasks,
            batch_type="eval",
        )
[文档]
async def benchmark(self) -> bool:
    """Benchmark the model checkpoints.

    With ``bench_on_latest_checkpoint`` only the checkpoint returned by the
    synchronizer is evaluated. Otherwise every ``global_step_<n>`` directory
    under the checkpoint job dir whose step is above the current explore
    step is evaluated in ascending order.

    Returns:
        Always ``True``.
    """
    explorer_cfg = self.config.explorer

    # benchmark on the latest checkpoint
    if explorer_cfg.bench_on_latest_checkpoint:
        self.explore_step_num = await self._checkpoint_weights_update()
        await self.eval()
        await self._finish_eval_step(prefix="bench")
        return True

    # benchmark on base model
    if explorer_cfg.eval_on_startup:
        await self._finish_eval_step(prefix="bench")

    # benchmark on all checkpoints
    checkpoint_root = self.config.checkpoint_job_dir
    checkpoint_steps = []
    for entry in os.listdir(checkpoint_root):
        if not os.path.isdir(os.path.join(checkpoint_root, entry)):
            continue
        if not entry.startswith("global_step_"):
            continue
        checkpoint_steps.append(int(entry.split("global_step_")[-1]))
    checkpoint_steps.sort()

    for step_num in checkpoint_steps:
        # Skip checkpoints that are not newer than the current step.
        if step_num > self.explore_step_num:
            self.explore_step_num = await self._checkpoint_weights_update(step_num=step_num)
            await self.eval()
            await self._finish_eval_step(prefix="bench")
    return True
[文档]
async def save_checkpoint(self) -> None:
    """Persist the current explore step and the taskset states.

    An empty list is saved as taskset state when there is no taskset.
    """
    # save explore checkpoint
    if self.taskset:
        taskset_states = self.taskset.state_dict()
    else:
        taskset_states = []
    self.state.save_explorer(
        current_step=self.explore_step_num,
        taskset_states=taskset_states,
    )
[文档]
async def sync_weight(self) -> None:
    """Synchronize model weights.

    Uses NCCL when configured, otherwise pulls the latest weights from the
    synchronizer, and saves an explorer checkpoint afterwards.
    """
    # call this method before training start to load the latest model weights
    at_startup = self.rollout_coordinator is not None and self.explore_step_num == 0
    if at_startup:
        # Collect the startup evaluation (step 0) before changing the weights.
        await self._finish_eval_step(step=0)

    self.logger.info(f"Explorer sync_weights at step {self.explore_step_num} started.")
    if not self.use_nccl_sync:
        # pull weights from Synchronizer
        await self._pull_latest_weights()
    else:
        await self._nccl_weights_update()
    self.logger.info(
        f"Explorer sync_weights at step {self.explore_step_num} finished, model version = {self.model_version}."
    )
    await self.save_checkpoint()
async def _finish_steps(self, start_step: int, end_step: int, model_version: int) -> None:
    """Finish the explore and eval work of every step in ``[start_step, end_step]``.

    Afterwards, if the interval timer is running, the elapsed time is logged
    as ``time/explorer_sync_interval`` at ``end_step`` and the timer is reset.
    """
    for current_step in range(start_step, end_step + 1):
        self.logger.info(f"Waiting for step {current_step}")
        await self._finish_explore_step(step=current_step, model_version=model_version)
        await self._finish_eval_step(step=current_step)

    # Record the time: read_task + explore_step (>=1) + eval (if any)
    if self.explore_start_time is None:
        return
    interval_metric = {"time/explorer_sync_interval": time.time() - self.explore_start_time}
    self.explore_start_time = None
    if self.monitor is not None:
        self.monitor.log(interval_metric, step=end_step)
async def _finish_explore_step(self, step: int, model_version: int) -> None:
    """Finalize the training batch of ``step`` and report its metrics.

    The batch metrics are fed back to the taskset (if any) and logged to the
    monitor only when at least one task finished.
    """
    assert self.rollout_coordinator is not None, "Rollout coordinator must be prepared first."
    metric = {"rollout/model_version": model_version}
    with Timer(metric, "time/wait_explore_step"):
        batch_result = await self.rollout_coordinator.finalize_train_batch.remote(step)
        if self.taskset is not None:
            self.taskset.feedback(batch_result["metrics"])
        metric.update(batch_result["metrics"])

    if batch_result["finished_task_count"] > 0 and self.monitor is not None:
        self.monitor.log(metric, step=step)
async def _finish_eval_step(self, step: Optional[int] = None, prefix: str = "eval") -> None:
    """Collect and log the pending evaluation results of one step.

    Pending eval batches are consumed from the front of
    ``pending_eval_tasks`` for as long as they belong to ``step``; batches of
    other steps stay queued. Nothing is logged when no batch of ``step`` is
    at the front of the queue.

    Args:
        step: Step whose eval batches should be finalized. ``None`` means the
            current ``explore_step_num``; ``0`` is honored as a real step.
        prefix: Metric namespace; when not ``"eval"``, a leading ``eval/`` in
            each metric key is replaced by ``<prefix>/``.
    """
    if not self.pending_eval_tasks:
        return
    # Explicit None check: `step or ...` would silently replace a valid step 0.
    if step is None:
        step = self.explore_step_num

    metric = {}
    finalized_any = False
    while self.pending_eval_tasks:
        eval_step, eval_task_name = self.pending_eval_tasks[0]
        if eval_step != step:
            # A different step is next in the queue: stop collecting, but still
            # report what was gathered for this step instead of dropping it.
            break
        self.pending_eval_tasks.popleft()
        assert (
            self.rollout_coordinator is not None
        ), "Rollout coordinator must be prepared first."
        result = await self.rollout_coordinator.finalize_eval_batch.remote(
            f"{step}/{eval_task_name}"
        )
        batch_metrics = result["metrics"]
        if prefix != "eval":
            batch_metrics = {
                key.replace("eval/", f"{prefix}/", 1) if key.startswith("eval/") else key: value
                for key, value in batch_metrics.items()
            }
        metric.update(batch_metrics)
        finalized_any = True

    if not finalized_any:
        return

    if self.eval_start_time is not None:
        metric.update({f"time/{prefix}": time.time() - self.eval_start_time})
        self.eval_start_time = None
    if self.monitor is not None:
        self.monitor.log(metric, step)
[文档]
async def shutdown(self) -> None:
    """Shut down the rollout coordinator, the monitor and all models."""
    if self.rollout_coordinator:
        await self.rollout_coordinator.shutdown.remote()
        self.rollout_coordinator = None
    if self.monitor:
        self.monitor.close()
        self.monitor = None

    # Issue every model shutdown first, then wait for all of them together.
    shutdown_refs = [model.shutdown.remote() for model in self.models]
    shutdown_refs.extend(
        aux_model.shutdown.remote()
        for aux_group in self.auxiliary_models
        for aux_model in aux_group
    )
    await asyncio.gather(*shutdown_refs)
    self.logger.info(
        f"Explorer ({self.config.explorer.name}) shutdown successfully at step {self.explore_step_num}."
    )
[文档]
async def is_alive(self) -> bool:
    """Liveness probe: always returns ``True`` once the call is served."""
    alive = True
    return alive
[文档]
@Experimental
async def serve(self) -> None:
    """Run the explorer in serving mode.

    In serving mode, the explorer starts an OpenAI compatible server to handle requests.
    Agent applications can be deployed separately and interact with the explorer via the API.

    .. code-block:: python

        import openai

        client = openai.OpenAI(
            base_url=f"{explorer_server_url}/v1",
            api_key="EMPTY",
        )
        response = client.chat.completions.create(
            model=config.model.model_path,
            messages=[{"role": "user", "content": "Hello!"}]
        )

    After the service is started, this coroutine loops forever, periodically
    passing the synchronizer's latest model version to the service.
    """
    from trinity.explorer.proxy.service import ExplorerService

    explorer_cfg = self.config.explorer
    self.service = ExplorerService(
        self,
        listen_address=explorer_cfg.listen_address,
        port=explorer_cfg.proxy_port,
    )
    await self.service.serve()

    node_ip = ray.util.get_node_ip_address()
    self.server_url = f"http://{node_ip}:{self.service.port}"
    self.logger.info(
        "======================================================\n"
        f"Starting Trinity Service on {self.server_url}\n"
        "======================================================"
    )
    self.state.save_explorer_server_url(self.server_url)

    while True:
        await asyncio.sleep(self.config.explorer.service_status_check_interval)
        # get the latest checkpoint
        latest_version = await self.synchronizer.get_latest_model_version.remote()
        self.service.set_latest_model_version(latest_version)
[文档]
@classmethod
def get_actor(cls, config: Config):
    """Get a Ray actor for the explorer.

    The actor is named after ``config.explorer.name`` and created in the
    ``config.ray_namespace`` namespace.
    """
    remote_cls = ray.remote(cls)
    actor_options = {
        "name": config.explorer.name,
        "namespace": config.ray_namespace,
    }
    return remote_cls.options(**actor_options).remote(config)