trinity.utils.stream_saver module

Streaming safetensors saver: writes one tensor at a time to avoid running out of memory (OOM).

The standard safetensors.torch.save_file requires the entire state dict to be in memory at once, which for 70B+ models can exceed the available CPU RAM. This module provides save_safetensors_streaming(), which iterates over a tensor generator and writes each tensor directly to disk, so only one tensor is held in memory at any point.

The implementation uses a seek-back approach:

  1. Reserve space for the safetensors header (exactly sized from pre-collected state_dict_meta).

  2. Write tensor data sequentially, collecting metadata along the way.

  3. Build the JSON header and seek back to the beginning to write it.
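The three steps above can be sketched in pure Python. This is a minimal illustration, not the trinity implementation: it assumes each tensor arrives as raw little-endian bytes together with a safetensors dtype code (F32, BF16, ...) and a shape, and the names DTYPE_SIZES, build_header and write_streaming are hypothetical. Real code would take the bytes from a CPU torch.Tensor.

```python
import json
import math
import struct
from typing import Iterable, Iterator, Sequence

# Bytes per element for a subset of safetensors dtype codes (illustrative only).
DTYPE_SIZES = {"F64": 8, "F32": 4, "F16": 2, "BF16": 2, "I64": 8, "I32": 4, "U8": 1}

Meta = tuple[str, str, Sequence[int]]  # (name, dtype code, shape)


def build_header(entries: Iterable[Meta]) -> bytes:
    """Serialize a safetensors JSON header; data offsets follow iteration order."""
    header, offset = {}, 0
    for name, dtype, shape in entries:
        nbytes = DTYPE_SIZES[dtype] * math.prod(shape)
        header[name] = {"dtype": dtype, "shape": list(shape),
                        "data_offsets": [offset, offset + nbytes]}
        offset += nbytes
    raw = json.dumps(header, separators=(",", ":")).encode("utf-8")
    # The format allows trailing spaces; pad so the data section is 8-byte aligned.
    return raw + b" " * (-len(raw) % 8)


def write_streaming(path: str,
                    tensors: Iterator[tuple[str, str, Sequence[int], bytes]],
                    meta: Sequence[Meta]) -> None:
    # Step 1: size the header exactly from pre-collected metadata and reserve it.
    reserved = len(build_header(meta))
    collected: list[Meta] = []
    with open(path, "wb") as f:
        f.write(b"\x00" * (8 + reserved))  # 8-byte length prefix + header slot
        # Step 2: write each tensor's bytes as it arrives, recording its metadata.
        for name, dtype, shape, data in tensors:
            if len(data) != DTYPE_SIZES[dtype] * math.prod(shape):
                raise ValueError(f"{name}: byte length does not match dtype/shape")
            f.write(data)
            collected.append((name, dtype, tuple(shape)))
            del data  # drop this function's reference before the next tensor is produced
        # Step 3: build the real header and seek back to fill the reserved slot.
        header = build_header(collected)
        if len(header) != reserved:
            raise ValueError("header size differs from reserved space; meta does not match")
        f.seek(0)
        f.write(struct.pack("<Q", len(header)))
        f.write(header)
```

Because the reserved size and the final header are produced by the same serializer, the header fits exactly whenever the metadata matches what was yielded; a mismatch that changes the header size is reported as an error instead of leaving a misaligned file.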

The resulting file is a fully valid safetensors file that can be loaded by safetensors.torch.load_file (which uses mmap internally).
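To illustrate why such a file works with mmap-based readers, the hypothetical read_tensor_bytes helper below (a stdlib-only sketch, not part of the safetensors API) parses the 8-byte little-endian length prefix and the JSON header, then memory-maps the file and slices out a single tensor's bytes without reading the rest of the data section.

```python
import json
import mmap
import struct


def read_tensor_bytes(path: str, name: str) -> bytes:
    """Return the raw bytes of one tensor from a safetensors file via mmap."""
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_len))  # trailing padding spaces are valid JSON whitespace
        begin, end = header[name]["data_offsets"]
        data_start = 8 + header_len  # data_offsets are relative to the data section
        with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm:
            return mm[data_start + begin:data_start + end]
```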

Writes go through the OS page cache, so the function returns quickly once the last tensor is written. Callers can hand off fsync + rename to a background thread to minimize blocking on the main thread.
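A minimal sketch of that hand-off, assuming the file was written with rename=False so that only the ".tmp" file exists. The helper names finalize_in_background and _fsync_and_rename are hypothetical; they use only os.fsync, os.replace and threading from the standard library.

```python
import os
import threading


def _fsync_and_rename(tmp_path: str, final_path: str) -> None:
    # Flush the temp file's cached pages to stable storage ("r+b" keeps it writable
    # without truncating, which fsync needs on some platforms).
    with open(tmp_path, "r+b") as f:
        os.fsync(f.fileno())
    # On POSIX this is an atomic rename when both paths are on the same filesystem.
    os.replace(tmp_path, final_path)


def finalize_in_background(tmp_path: str, final_path: str) -> threading.Thread:
    """Run fsync + rename off the main thread; join() before relying on final_path."""
    t = threading.Thread(target=_fsync_and_rename, args=(tmp_path, final_path))
    t.start()
    return t
```

For the rename itself to be durable across a crash on POSIX systems, the parent directory is usually fsynced as well; that step is omitted here.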

class trinity.utils.stream_saver.TensorMeta(name: str, dtype: str, shape: list[int])

Bases: NamedTuple

Metadata for a single tensor, sufficient to compute its safetensors header entry.

name: str

Alias for field number 0

dtype: str

Alias for field number 1

shape: list[int]

Alias for field number 2
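As a NamedTuple, TensorMeta fields can be read by name or by index, and together they determine how many bytes the tensor occupies in the data section. The sketch below defines a local stand-in with the same fields plus a hypothetical tensor_nbytes helper; the dtype spellings in ITEMSIZE (such as "bfloat16") are an assumption, since the accepted strings are not documented here.

```python
import math
from typing import NamedTuple


class TensorMeta(NamedTuple):
    """Local stand-in mirroring the documented fields (not the trinity class)."""
    name: str
    dtype: str
    shape: list[int]


# Assumed dtype spellings mapped to bytes per element (illustrative subset).
ITEMSIZE = {"float32": 4, "float16": 2, "bfloat16": 2, "int64": 8, "uint8": 1}


def tensor_nbytes(meta: TensorMeta) -> int:
    # Element count times element size; math.prod([]) == 1 covers 0-d scalars.
    return ITEMSIZE[meta.dtype] * math.prod(meta.shape)
```

For example, TensorMeta("model.embed_tokens.weight", "bfloat16", [4, 8]) describes 32 elements of 2 bytes each, so tensor_nbytes returns 64, and m[0] is the same value as m.name.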

trinity.utils.stream_saver.save_safetensors_streaming(tensor_iter: Iterable[tuple[str, Tensor]], filepath: str | PathLike, state_dict_meta: Iterable[TensorMeta | tuple[str, str | dtype, Sequence[int]]], *, rename: bool = True) → str

Stream-write a safetensors file, holding only one tensor in memory.

Parameters:
  • tensor_iter -- Yields (name, tensor) pairs. Each tensor must already reside on CPU. Once a tensor's bytes have been written, the function no longer references it; the caller should also del the tensor on their side so its memory can be freed.

  • filepath -- Destination path. Data is first written to filepath + ".tmp"; if rename is True the temp file is atomically renamed to filepath before returning.

  • state_dict_meta -- Pre-collected (name, dtype, shape) metadata for every tensor that tensor_iter will yield. Supports both the internal TensorMeta format and plain tuples such as (name, "bfloat16", (hidden, hidden)). The header space is computed exactly from this metadata, so the header always fits the reserved space and never has to be rewritten. Callers should collect this once during initialization via _cache_state_dict_meta() and reuse it for every save.

  • rename -- If True (default), atomically replace filepath with the completed temp file. If False, leave the temp file on disk; the caller is then responsible for fsync + rename (useful for hybrid async mode).
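Since state_dict_meta accepts both TensorMeta entries and plain (name, dtype, shape) tuples whose dtype may be a string or a torch.dtype, the entries need a common form before the header can be sized. One possible normalization is sketched below; normalize_meta and the local TensorMeta stand-in are hypothetical. The sketch relies on str(torch.bfloat16) being "torch.bfloat16", so it strips that prefix without importing torch.

```python
from typing import Any, Iterable, NamedTuple


class TensorMeta(NamedTuple):
    """Local stand-in mirroring the documented fields (not the trinity class)."""
    name: str
    dtype: str
    shape: list[int]


def normalize_meta(entries: Iterable[Any]) -> list[TensorMeta]:
    """Coerce TensorMeta or (name, dtype, shape) tuples into TensorMeta with str dtypes."""
    out = []
    for name, dtype, shape in entries:  # a TensorMeta unpacks like a plain tuple
        dtype_str = str(dtype).removeprefix("torch.")  # torch.bfloat16 -> "bfloat16"
        out.append(TensorMeta(name, dtype_str, [int(d) for d in shape]))
    return out
```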

Returns:

The path actually written: filepath when rename=True, otherwise str(filepath) + ".tmp".

Return type:

str