trinity.utils.stream_saver module

Streaming safetensors saver — saves one tensor at a time to avoid OOM.

Standard safetensors.torch.save_file requires the entire state dict in memory. For 70B+ models this can exceed available CPU RAM. This module provides save_safetensors_streaming() which iterates a tensor generator and writes each tensor directly to disk, keeping only one tensor in memory at any point.

The implementation uses a seek-back approach:

  1. Reserve space for the safetensors header (exactly sized from pre-collected state_dict_meta).

  2. Write tensor data sequentially, collecting metadata along the way.

  3. Build the JSON header and seek back to the beginning to write it.
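The three steps above can be sketched with the standard library alone. This is an illustrative sketch, not the module's implementation. It works on raw `bytes` payloads rather than torch tensors, and the names `save_streaming_sketch`, `build_header`, `tensor_nbytes` and `DTYPE_SIZES` are hypothetical. The space padding of the header to an 8-byte boundary is an assumption based on how the reference safetensors serializer aligns tensor data.

```python
import json
import struct

# Bytes per element for a few safetensors dtype codes (illustrative, not exhaustive).
DTYPE_SIZES = {"F64": 8, "F32": 4, "F16": 2, "BF16": 2, "I64": 8, "I32": 4, "U8": 1}


def tensor_nbytes(dtype, shape):
    """Size in bytes of a tensor with the given dtype code and shape."""
    n = DTYPE_SIZES[dtype]
    for dim in shape:
        n *= dim
    return n


def build_header(meta):
    """Serialize the JSON header for (name, dtype_code, shape) entries."""
    header, offset = {}, 0
    for name, dtype, shape in meta:
        end = offset + tensor_nbytes(dtype, shape)
        header[name] = {"dtype": dtype, "shape": list(shape), "data_offsets": [offset, end]}
        offset = end
    raw = json.dumps(header, separators=(",", ":")).encode("utf-8")
    # Assumption: pad with spaces to a multiple of 8 bytes for data alignment.
    return raw + b" " * (-len(raw) % 8)


def save_streaming_sketch(chunks, path, meta):
    """Write (name, raw_bytes) pairs with the reserve / write / seek-back steps."""
    meta = list(meta)
    reserved = 8 + len(build_header(meta))  # 8-byte length prefix + header bytes
    written = []
    with open(path, "wb") as f:
        # Step 1: reserve the header region with placeholder bytes.
        f.write(b"\0" * reserved)
        # Step 2: append each payload in order, recording what was written.
        for index, (name, data) in enumerate(chunks):
            if index >= len(meta) or meta[index][0] != name:
                raise ValueError(f"unexpected tensor {name!r}")
            _, dtype, shape = meta[index]
            if len(data) != tensor_nbytes(dtype, shape):
                raise ValueError(f"size mismatch for tensor {name!r}")
            f.write(data)
            written.append((name, dtype, shape))
        if len(written) != len(meta):
            raise ValueError("fewer tensors than metadata entries")
        # Step 3: build the real header and seek back to fill the reserved region.
        header = build_header(written)
        if 8 + len(header) != reserved:
            raise ValueError("header does not fit the reserved space")
        f.seek(0)
        f.write(struct.pack("<Q", len(header)))
        f.write(header)
    return path
```

Because the header is sized from the metadata before any payload is written, the data offsets never shift and the payload region is written exactly once.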

The resulting file is a fully valid safetensors file that can be loaded by safetensors.torch.load_file (which uses mmap internally).
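One way to check the on-disk layout without loading any tensor data is to parse the header by hand. The sketch below assumes the published safetensors layout: an 8-byte little-endian header length, a JSON header, then the tensor bytes, with an optional `__metadata__` key in the header. The function name `read_safetensors_header` is hypothetical and is not part of this module.

```python
import json
import os
import struct


def read_safetensors_header(path):
    """Return (header_dict, data_start) after basic layout checks."""
    file_size = os.path.getsize(path)
    with open(path, "rb") as f:
        prefix = f.read(8)
        if len(prefix) != 8:
            raise ValueError("file too short for the length prefix")
        (header_len,) = struct.unpack("<Q", prefix)
        if 8 + header_len > file_size:
            raise ValueError("header length exceeds file size")
        # Trailing space padding in the header is tolerated by the JSON parser.
        header = json.loads(f.read(header_len))

    data_start = 8 + header_len
    data_size = file_size - data_start
    furthest = 0
    for name, entry in header.items():
        if name == "__metadata__":  # optional free-form string metadata
            continue
        begin, end = entry["data_offsets"]
        if not 0 <= begin <= end <= data_size:
            raise ValueError(f"bad offsets for tensor {name!r}")
        furthest = max(furthest, end)
    if furthest != data_size:
        raise ValueError("data region size does not match the header")
    return header, data_start
```

A check like this is cheap enough to run right after a save, since it reads only the header and the file size.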

Writes go through the OS page cache, so the function returns quickly after the last tensor is written. Callers can hand off fsync + rename to a background thread for minimal main-thread blocking.
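The hand-off described above might look like the following sketch. The helper name `finalize_in_background` is hypothetical and not part of this module; it assumes the caller already holds the path of a fully written temp file, for example the path returned when rename=False.

```python
import os
import threading


def finalize_in_background(tmp_path, final_path):
    """fsync tmp_path, then atomically rename it to final_path on a worker thread."""

    def _worker():
        # Reopen the file and flush its dirty pages to stable storage.
        fd = os.open(tmp_path, os.O_RDWR)
        try:
            os.fsync(fd)
        finally:
            os.close(fd)
        # os.replace overwrites an existing destination atomically
        # when both paths are on the same filesystem.
        os.replace(tmp_path, final_path)

    thread = threading.Thread(target=_worker, name="ckpt-finalize")
    thread.start()
    return thread
```

The caller should join() the returned thread before starting the next save to the same path. Exceptions raised inside the worker are not propagated to the main thread in this sketch, so production code would need to capture and report them.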

class trinity.utils.stream_saver.TensorMeta(name: str, dtype: str, shape: list[int])

Bases: NamedTuple

Metadata for a single tensor, sufficient to compute its safetensors header entry.

name: str

Alias for field number 0

dtype: str

Alias for field number 1

shape: list[int]

Alias for field number 2

trinity.utils.stream_saver.save_safetensors_streaming(tensor_iter: Iterable[tuple[str, Tensor]], filepath: str | PathLike, state_dict_meta: Iterable[TensorMeta | tuple[str, str | dtype, Sequence[int]]], *, rename: bool = True) → str

Stream-write a safetensors file, holding only one tensor in memory.

Parameters:
  • tensor_iter – Yields (name, tensor) pairs. Each tensor must already reside on CPU. Once a tensor's bytes are written, the function drops its reference to it; the caller should also del its own reference so the memory can be freed.

  • filepath – Destination path. Data is first written to filepath + ".tmp"; if rename is True the temp file is atomically renamed to filepath before returning.

  • state_dict_meta – Pre-collected (name, dtype, shape) metadata for every tensor that tensor_iter will yield. Supports both the internal TensorMeta format and plain tuples such as (name, "bfloat16", (hidden, hidden)). The header space is computed exactly from this metadata, so the header will always fit the reserved space without any rewrite. Callers should collect this once during initialization via _cache_state_dict_meta() and reuse it for every save.

  • rename – If True (default), atomically replace filepath with the completed temp file. If False, leave the temp file on disk — the caller is responsible for fsync + rename (useful for hybrid async mode).

Returns:

The path actually written: filepath when rename=True, otherwise str(filepath) + ".tmp".

Return type:

str