Vendor qroissant 0.3.0 baseline
This commit is contained in:
commit
53ac90fe84
56 changed files with 18309 additions and 0 deletions
475
crates/qroissant-transport/src/asynchronous.rs
Normal file
475
crates/qroissant-transport/src/asynchronous.rs
Normal file
|
|
@ -0,0 +1,475 @@
|
|||
use std::pin::Pin;
|
||||
use std::task::Context;
|
||||
use std::task::Poll;
|
||||
use std::time::Duration;
|
||||
|
||||
use qroissant_core::Compression;
|
||||
use qroissant_core::HEADER_LEN;
|
||||
use qroissant_core::MessageHeader;
|
||||
use qroissant_core::StreamingDecompressor;
|
||||
use qroissant_core::read_message_length;
|
||||
use tokio::io::AsyncRead;
|
||||
use tokio::io::AsyncReadExt;
|
||||
use tokio::io::AsyncWrite;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio::io::ReadBuf;
|
||||
use tokio::net::TcpStream;
|
||||
#[cfg(unix)]
|
||||
use tokio::net::UnixStream;
|
||||
|
||||
use crate::TransportError;
|
||||
use crate::TransportResult;
|
||||
use crate::synchronous::CLIENT_CAPABILITY;
|
||||
|
||||
pub enum AsyncTransport {
|
||||
Tcp(TcpStream),
|
||||
#[cfg(unix)]
|
||||
Unix(UnixStream),
|
||||
}
|
||||
|
||||
impl AsyncRead for AsyncTransport {
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<std::io::Result<()>> {
|
||||
match &mut *self {
|
||||
Self::Tcp(stream) => Pin::new(stream).poll_read(cx, buf),
|
||||
#[cfg(unix)]
|
||||
Self::Unix(stream) => Pin::new(stream).poll_read(cx, buf),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncWrite for AsyncTransport {
|
||||
fn poll_write(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<std::io::Result<usize>> {
|
||||
match &mut *self {
|
||||
Self::Tcp(stream) => Pin::new(stream).poll_write(cx, buf),
|
||||
#[cfg(unix)]
|
||||
Self::Unix(stream) => Pin::new(stream).poll_write(cx, buf),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
|
||||
match &mut *self {
|
||||
Self::Tcp(stream) => Pin::new(stream).poll_flush(cx),
|
||||
#[cfg(unix)]
|
||||
Self::Unix(stream) => Pin::new(stream).poll_flush(cx),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
|
||||
match &mut *self {
|
||||
Self::Tcp(stream) => Pin::new(stream).poll_shutdown(cx),
|
||||
#[cfg(unix)]
|
||||
Self::Unix(stream) => Pin::new(stream).poll_shutdown(cx),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncTransport {
|
||||
pub async fn shutdown(&mut self) -> std::io::Result<()> {
|
||||
match self {
|
||||
Self::Tcp(stream) => stream.shutdown().await,
|
||||
#[cfg(unix)]
|
||||
Self::Unix(stream) => stream.shutdown().await,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn take_error(&self) -> std::io::Result<Option<std::io::Error>> {
|
||||
match self {
|
||||
Self::Tcp(stream) => stream.take_error(),
|
||||
#[cfg(unix)]
|
||||
Self::Unix(stream) => stream.take_error(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct AsyncPooledTransport {
|
||||
transport: AsyncTransport,
|
||||
broken: bool,
|
||||
}
|
||||
|
||||
impl AsyncPooledTransport {
|
||||
pub fn new(transport: AsyncTransport) -> Self {
|
||||
Self {
|
||||
transport,
|
||||
broken: false,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn mark_broken(&mut self) {
|
||||
self.broken = true;
|
||||
}
|
||||
|
||||
pub fn is_broken(&self) -> bool {
|
||||
self.broken || self.transport.take_error().ok().flatten().is_some()
|
||||
}
|
||||
|
||||
pub fn transport_mut(&mut self) -> &mut AsyncTransport {
|
||||
&mut self.transport
|
||||
}
|
||||
}
|
||||
|
||||
/// A reader that transparently decompresses q IPC payloads as they are read.
|
||||
pub struct DecompressingReader<'a, R> {
|
||||
reader: &'a mut R,
|
||||
decompressor: Option<StreamingDecompressor>,
|
||||
remaining_compressed: usize,
|
||||
buffer: Vec<u8>,
|
||||
}
|
||||
|
||||
impl<'a, R: AsyncRead + Unpin> DecompressingReader<'a, R> {
|
||||
pub fn new(
|
||||
reader: &'a mut R,
|
||||
decompressor: Option<StreamingDecompressor>,
|
||||
remaining_compressed: usize,
|
||||
) -> Self {
|
||||
Self {
|
||||
reader,
|
||||
decompressor,
|
||||
remaining_compressed,
|
||||
buffer: vec![0_u8; 8192],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, R: AsyncRead + Unpin> AsyncRead for DecompressingReader<'a, R> {
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<std::io::Result<()>> {
|
||||
let this = &mut *self;
|
||||
|
||||
if let Some(decompressor) = &mut this.decompressor {
|
||||
// If we have decompressed data available, yield it first.
|
||||
if decompressor.unread_len() > 0 {
|
||||
let chunk = decompressor.next_chunk();
|
||||
let to_copy = chunk.len().min(buf.remaining());
|
||||
buf.put_slice(&chunk[..to_copy]);
|
||||
decompressor.consume(to_copy);
|
||||
return Poll::Ready(Ok(()));
|
||||
}
|
||||
|
||||
// If decompression is complete and no more unread bytes, EOF.
|
||||
if decompressor.is_complete() {
|
||||
return Poll::Ready(Ok(()));
|
||||
}
|
||||
|
||||
// Otherwise, read more compressed data from the underlying reader.
|
||||
if this.remaining_compressed > 0 {
|
||||
let want = this.remaining_compressed.min(this.buffer.len());
|
||||
let mut read_buf = ReadBuf::new(&mut this.buffer[..want]);
|
||||
match Pin::new(&mut this.reader).poll_read(cx, &mut read_buf) {
|
||||
Poll::Ready(Ok(())) => {
|
||||
let read = read_buf.filled().len();
|
||||
if read == 0 && want > 0 {
|
||||
return Poll::Ready(Err(std::io::Error::new(
|
||||
std::io::ErrorKind::UnexpectedEof,
|
||||
"unexpected EOF reading compressed body",
|
||||
)));
|
||||
}
|
||||
this.remaining_compressed -= read;
|
||||
decompressor.feed(read_buf.filled()).map_err(|e| {
|
||||
std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string())
|
||||
})?;
|
||||
|
||||
// Recursive call to yield the newly decompressed bytes.
|
||||
return self.poll_read(cx, buf);
|
||||
}
|
||||
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
|
||||
Poll::Pending => return Poll::Pending,
|
||||
}
|
||||
}
|
||||
Poll::Ready(Ok(()))
|
||||
} else {
|
||||
// Uncompressed path: direct read from underlying reader.
|
||||
Pin::new(&mut this.reader).poll_read(cx, buf)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncRead for AsyncPooledTransport {
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<std::io::Result<()>> {
|
||||
Pin::new(&mut self.transport).poll_read(cx, buf)
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncWrite for AsyncPooledTransport {
|
||||
fn poll_write(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<std::io::Result<usize>> {
|
||||
Pin::new(&mut self.transport).poll_write(cx, buf)
|
||||
}
|
||||
|
||||
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
|
||||
Pin::new(&mut self.transport).poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
|
||||
Pin::new(&mut self.transport).poll_shutdown(cx)
|
||||
}
|
||||
}
|
||||
|
||||
fn credentials_bytes(username: Option<&str>, password: Option<&str>) -> Vec<u8> {
|
||||
let username = username.unwrap_or_default();
|
||||
let password = password.unwrap_or_default();
|
||||
let mut bytes = format!("{username}:{password}").into_bytes();
|
||||
bytes.push(CLIENT_CAPABILITY);
|
||||
bytes.push(0);
|
||||
bytes
|
||||
}
|
||||
|
||||
fn timeout_error(context: &str, timeout_ms: u64) -> TransportError {
|
||||
TransportError::Io(std::io::Error::new(
|
||||
std::io::ErrorKind::TimedOut,
|
||||
format!("{context} timed out after {timeout_ms}ms"),
|
||||
))
|
||||
}
|
||||
|
||||
async fn run_with_timeout<T, F>(
|
||||
timeout_ms: Option<u64>,
|
||||
context: &str,
|
||||
future: F,
|
||||
) -> TransportResult<T>
|
||||
where
|
||||
F: std::future::Future<Output = std::io::Result<T>>,
|
||||
{
|
||||
match timeout_ms {
|
||||
Some(timeout_ms) => tokio::time::timeout(Duration::from_millis(timeout_ms), future)
|
||||
.await
|
||||
.map_err(|_| timeout_error(context, timeout_ms))?
|
||||
.map_err(TransportError::Io),
|
||||
None => future.await.map_err(TransportError::Io),
|
||||
}
|
||||
}
|
||||
|
||||
async fn perform_handshake<S>(
|
||||
stream: &mut S,
|
||||
username: Option<&str>,
|
||||
password: Option<&str>,
|
||||
timeout_ms: Option<u64>,
|
||||
) -> TransportResult<u8>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
run_with_timeout(
|
||||
timeout_ms,
|
||||
"q IPC handshake write",
|
||||
stream.write_all(&credentials_bytes(username, password)),
|
||||
)
|
||||
.await?;
|
||||
run_with_timeout(timeout_ms, "q IPC handshake flush", stream.flush()).await?;
|
||||
|
||||
let mut capability = [0_u8; 1];
|
||||
run_with_timeout(
|
||||
timeout_ms,
|
||||
"q IPC handshake read",
|
||||
stream.read_exact(&mut capability),
|
||||
)
|
||||
.await?;
|
||||
Ok(capability[0])
|
||||
}
|
||||
|
||||
pub async fn connect_tcp_transport(
|
||||
host: &str,
|
||||
port: u16,
|
||||
username: Option<&str>,
|
||||
password: Option<&str>,
|
||||
timeout_ms: Option<u64>,
|
||||
) -> TransportResult<AsyncTransport> {
|
||||
let mut stream =
|
||||
run_with_timeout(timeout_ms, "TCP connect", TcpStream::connect((host, port))).await?;
|
||||
stream.set_nodelay(true)?;
|
||||
perform_handshake(&mut stream, username, password, timeout_ms).await?;
|
||||
Ok(AsyncTransport::Tcp(stream))
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
pub async fn connect_unix_transport(
|
||||
path: &str,
|
||||
username: Option<&str>,
|
||||
password: Option<&str>,
|
||||
timeout_ms: Option<u64>,
|
||||
) -> TransportResult<AsyncTransport> {
|
||||
let mut stream =
|
||||
run_with_timeout(timeout_ms, "Unix socket connect", UnixStream::connect(path)).await?;
|
||||
perform_handshake(&mut stream, username, password, timeout_ms).await?;
|
||||
Ok(AsyncTransport::Unix(stream))
|
||||
}
|
||||
|
||||
pub async fn read_frame<S>(stream: &mut S) -> TransportResult<Vec<u8>>
|
||||
where
|
||||
S: AsyncRead + Unpin,
|
||||
{
|
||||
let mut header = [0_u8; HEADER_LEN];
|
||||
stream.read_exact(&mut header).await?;
|
||||
let message_length = read_message_length(&header)
|
||||
.map_err(|error| TransportError::Protocol(error.to_string()))?;
|
||||
let mut frame = vec![0_u8; message_length];
|
||||
frame[..HEADER_LEN].copy_from_slice(&header);
|
||||
stream.read_exact(&mut frame[HEADER_LEN..]).await?;
|
||||
Ok(frame)
|
||||
}
|
||||
|
||||
pub async fn request_frame_over<S>(stream: &mut S, payload: &[u8]) -> TransportResult<Vec<u8>>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
stream.write_all(payload).await?;
|
||||
stream.flush().await?;
|
||||
read_frame(stream).await
|
||||
}
|
||||
|
||||
pub async fn begin_streaming_frame_over<S>(
|
||||
stream: &mut S,
|
||||
payload: &[u8],
|
||||
) -> TransportResult<([u8; HEADER_LEN], usize)>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
stream.write_all(payload).await?;
|
||||
stream.flush().await?;
|
||||
|
||||
let mut header = [0_u8; HEADER_LEN];
|
||||
stream.read_exact(&mut header).await?;
|
||||
let message_length = read_message_length(&header)
|
||||
.map_err(|error| TransportError::Protocol(error.to_string()))?;
|
||||
Ok((header, message_length - HEADER_LEN))
|
||||
}
|
||||
|
||||
/// Async variant of [`crate::synchronous::request_frame_streaming_over`].
|
||||
///
|
||||
/// Sends a payload and reads the response frame, using streaming decompression
|
||||
/// when the response is compressed.
|
||||
pub async fn request_frame_streaming_over<S>(
|
||||
stream: &mut S,
|
||||
payload: &[u8],
|
||||
) -> TransportResult<Vec<u8>>
|
||||
where
|
||||
S: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
stream.write_all(payload).await?;
|
||||
stream.flush().await?;
|
||||
|
||||
// Read the 8-byte header.
|
||||
let mut header_bytes = [0_u8; HEADER_LEN];
|
||||
stream.read_exact(&mut header_bytes).await?;
|
||||
|
||||
let header = MessageHeader::from_bytes(header_bytes)
|
||||
.map_err(|error| TransportError::Protocol(error.to_string()))?;
|
||||
let body_len = header.body_len();
|
||||
|
||||
if header.compression() == Compression::Uncompressed {
|
||||
let mut frame = vec![0_u8; header.size()];
|
||||
frame[..HEADER_LEN].copy_from_slice(&header_bytes);
|
||||
stream.read_exact(&mut frame[HEADER_LEN..]).await?;
|
||||
return Ok(frame);
|
||||
}
|
||||
|
||||
// Compressed frame: read the 4-byte size prefix first.
|
||||
if body_len < 4 {
|
||||
return Err(TransportError::Protocol(
|
||||
"compressed body must be at least 4 bytes for size prefix".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let mut size_prefix = [0_u8; 4];
|
||||
stream.read_exact(&mut size_prefix).await?;
|
||||
|
||||
let mut decompressor = StreamingDecompressor::new(size_prefix, header.encoding())
|
||||
.map_err(|error| TransportError::Protocol(error.to_string()))?;
|
||||
|
||||
// Read the remaining compressed body in chunks.
|
||||
let remaining = body_len - 4;
|
||||
let mut total_read = 0_usize;
|
||||
let mut chunk = vec![0_u8; 8192];
|
||||
|
||||
while total_read < remaining {
|
||||
let want = (remaining - total_read).min(chunk.len());
|
||||
stream.read_exact(&mut chunk[..want]).await?;
|
||||
decompressor
|
||||
.feed(&chunk[..want])
|
||||
.map_err(|error| TransportError::Protocol(error.to_string()))?;
|
||||
total_read += want;
|
||||
}
|
||||
|
||||
if !decompressor.is_complete() {
|
||||
return Err(TransportError::Protocol(
|
||||
"streaming decompression did not complete after reading entire body".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let decompressed = decompressor
|
||||
.finish()
|
||||
.map_err(|error| TransportError::Protocol(error.to_string()))?;
|
||||
|
||||
// Reconstruct as an uncompressed frame.
|
||||
let new_size = HEADER_LEN + decompressed.len();
|
||||
let new_header = qroissant_core::MessageHeader::new(
|
||||
header.encoding(),
|
||||
header.message_type(),
|
||||
Compression::Uncompressed,
|
||||
new_size,
|
||||
)
|
||||
.map_err(|error| TransportError::Protocol(error.to_string()))?;
|
||||
|
||||
let mut frame = Vec::with_capacity(new_size);
|
||||
frame.extend_from_slice(
|
||||
&new_header
|
||||
.to_bytes()
|
||||
.map_err(|error| TransportError::Protocol(error.to_string()))?,
|
||||
);
|
||||
frame.extend_from_slice(&decompressed);
|
||||
Ok(frame)
|
||||
}
|
||||
use qroissant_core::pipelined::PipelinedReader;
|
||||
use qroissant_core::pipelined::decode_value_async;
|
||||
use qroissant_core::value::Value;
|
||||
|
||||
pub async fn request_value_pipelined_over<R: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
|
||||
conn: &mut R,
|
||||
payload: &[u8],
|
||||
) -> TransportResult<Value> {
|
||||
conn.write_all(payload).await.map_err(TransportError::Io)?;
|
||||
conn.flush().await.map_err(TransportError::Io)?;
|
||||
|
||||
let mut header_bytes = [0_u8; HEADER_LEN];
|
||||
conn.read_exact(&mut header_bytes)
|
||||
.await
|
||||
.map_err(TransportError::Io)?;
|
||||
let header =
|
||||
MessageHeader::parse(&header_bytes).map_err(|e| TransportError::Protocol(e.to_string()))?;
|
||||
|
||||
let (decompressor, remaining_compressed) = if header.compression() != Compression::Uncompressed
|
||||
{
|
||||
let mut size_prefix = [0_u8; 4];
|
||||
conn.read_exact(&mut size_prefix)
|
||||
.await
|
||||
.map_err(TransportError::Io)?;
|
||||
let decompressor = StreamingDecompressor::new(size_prefix, header.encoding())
|
||||
.map_err(|e| TransportError::Protocol(e.to_string()))?;
|
||||
(Some(decompressor), header.body_len() - 4)
|
||||
} else {
|
||||
(None, header.body_len())
|
||||
};
|
||||
|
||||
let mut decomp_reader = DecompressingReader::new(conn, decompressor, remaining_compressed);
|
||||
let mut pipelined_reader = PipelinedReader::new(&mut decomp_reader, header.encoding())
|
||||
.map_err(|e| TransportError::Protocol(e.to_string()))?;
|
||||
|
||||
decode_value_async(&mut pipelined_reader)
|
||||
.await
|
||||
.map_err(|e| TransportError::Protocol(e.to_string()))
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue