Source code for trinity.buffer.reader.queue_reader
"""Reader of the Queue buffer."""
import traceback
from typing import Dict, List, Optional
from trinity.buffer.buffer_reader import BufferReader
from trinity.buffer.storage.queue import QueueStorage
from trinity.common.config import StorageConfig
from trinity.common.constants import StorageType
from trinity.common.experience import Experience
[docs]
class QueueReader(BufferReader):
    """Reader of the Queue buffer.

    Requests serialized experiences from the queue wrapper obtained via
    ``QueueStorage.get_wrapper`` and deserializes them into ``Experience``
    objects. The reader tracks no read position of its own.
    """

    def __init__(self, config: StorageConfig) -> None:
        """Bind the reader to the queue described by ``config``.

        Args:
            config: Storage configuration. Its ``storage_type`` must equal
                ``StorageType.QUEUE.value`` (checked with ``assert``). Its
                ``batch_size`` becomes the default read size and its
                ``max_read_timeout`` is passed to every fetch.
        """
        assert config.storage_type == StorageType.QUEUE.value
        # Default number of experiences requested when read() is given no size.
        self.read_batch_size = config.batch_size
        # Forwarded as ``timeout=`` on every get_batch call.
        self.timeout = config.max_read_timeout
        self.queue = QueueStorage.get_wrapper(config)

    async def read(self, batch_size: Optional[int] = None, **kwargs) -> List[Experience]:
        """Fetch exactly one batch of experiences from the queue.

        Args:
            batch_size: Number of experiences to request. ``None`` falls back
                to ``self.read_batch_size``.
            **kwargs: Extra keyword arguments forwarded to ``get_batch``.

        Returns:
            The deserialized experiences. The list always has exactly
            ``batch_size`` entries.

        Raises:
            StopAsyncIteration: If fetching fails and the formatted traceback
                of that failure mentions ``StopAsyncIteration``.
            TimeoutError: If the number of deserialized experiences differs
                from the requested batch size.
            Exception: Any other fetch failure is re-raised unchanged.
        """
        if batch_size is None:
            batch_size = self.read_batch_size
        try:
            payload = await self.queue.get_batch.remote(
                batch_size, timeout=self.timeout, **kwargs
            )
        except Exception as err:
            # The check matches on the formatted traceback text, not on the
            # exception type. NOTE(review): presumably this is because the
            # remote call may wrap the original exception in another type.
            # Confirm before replacing it with an isinstance check.
            if "StopAsyncIteration" not in traceback.format_exc():
                raise
            raise StopAsyncIteration() from err

        experiences = Experience.deserialize_many(payload)
        received = len(experiences)
        if received == batch_size:
            return experiences
        raise TimeoutError(
            f"Read incomplete batch ({received}/{batch_size}), please check your workflow."
        )

    def state_dict(self) -> Dict:
        """Return the reader state. The position is always reported as 0."""
        return dict(current_index=0)

    def load_state_dict(self, state_dict) -> None:
        """Accept ``state_dict`` and ignore it; there is no state to restore."""
        return None