Source code for trinity.buffer.reader.sql_reader
"""Reader of the SQL buffer."""
import traceback
from typing import Dict, List, Optional
from trinity.buffer.buffer_reader import BufferReader
from trinity.buffer.storage.sql import SQLExperienceStorage, SQLStorage, SQLTaskStorage
from trinity.common.config import StorageConfig
from trinity.common.constants import StorageType
[docs]
class SQLReader(BufferReader):
"""Reader of the SQL buffer."""
[docs]
def __init__(self, config: StorageConfig) -> None:
assert config.storage_type == StorageType.SQL.value
self.wrap_in_ray = config.wrap_in_ray
self.read_batch_size = config.batch_size
self._storage = None
self._async_storage = None
self._config = config
@property
def storage(self):
if self._storage is None:
self._storage = SQLStorage.get_wrapper(self._config)
return self._storage
async def _get_async_storage(self):
if self._async_storage is None:
if self._config.schema_type is None:
self._async_storage = SQLTaskStorage(self._config)
else:
self._async_storage = SQLExperienceStorage(self._config)
await self._async_storage.prepare()
return self._async_storage
[docs]
async def read(self, batch_size: Optional[int] = None, **kwargs) -> List:
batch_size = self.read_batch_size if batch_size is None else batch_size
if self.wrap_in_ray:
try:
return await self.storage.read.remote(batch_size, **kwargs)
except (StopIteration, StopAsyncIteration):
raise StopAsyncIteration()
except Exception as e:
if "StopAsyncIteration" in traceback.format_exc():
raise StopAsyncIteration() from e
raise
else:
storage = await self._get_async_storage()
return await storage.read(batch_size, **kwargs)
[docs]
def state_dict(self) -> Dict:
return {"current_index": 0}
[docs]
def load_state_dict(self, state_dict):
return None