Vendor qroissant 0.3.0 baseline

This commit is contained in:
Cam Zalewski 2026-05-20 14:11:30 +01:00
commit 53ac90fe84
56 changed files with 18309 additions and 0 deletions

View file

@ -0,0 +1,475 @@
use std::pin::Pin;
use std::task::Context;
use std::task::Poll;
use std::time::Duration;
use qroissant_core::Compression;
use qroissant_core::HEADER_LEN;
use qroissant_core::MessageHeader;
use qroissant_core::StreamingDecompressor;
use qroissant_core::read_message_length;
use tokio::io::AsyncRead;
use tokio::io::AsyncReadExt;
use tokio::io::AsyncWrite;
use tokio::io::AsyncWriteExt;
use tokio::io::ReadBuf;
use tokio::net::TcpStream;
#[cfg(unix)]
use tokio::net::UnixStream;
use crate::TransportError;
use crate::TransportResult;
use crate::synchronous::CLIENT_CAPABILITY;
pub enum AsyncTransport {
Tcp(TcpStream),
#[cfg(unix)]
Unix(UnixStream),
}
impl AsyncRead for AsyncTransport {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
match &mut *self {
Self::Tcp(stream) => Pin::new(stream).poll_read(cx, buf),
#[cfg(unix)]
Self::Unix(stream) => Pin::new(stream).poll_read(cx, buf),
}
}
}
impl AsyncWrite for AsyncTransport {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<std::io::Result<usize>> {
match &mut *self {
Self::Tcp(stream) => Pin::new(stream).poll_write(cx, buf),
#[cfg(unix)]
Self::Unix(stream) => Pin::new(stream).poll_write(cx, buf),
}
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
match &mut *self {
Self::Tcp(stream) => Pin::new(stream).poll_flush(cx),
#[cfg(unix)]
Self::Unix(stream) => Pin::new(stream).poll_flush(cx),
}
}
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
match &mut *self {
Self::Tcp(stream) => Pin::new(stream).poll_shutdown(cx),
#[cfg(unix)]
Self::Unix(stream) => Pin::new(stream).poll_shutdown(cx),
}
}
}
impl AsyncTransport {
pub async fn shutdown(&mut self) -> std::io::Result<()> {
match self {
Self::Tcp(stream) => stream.shutdown().await,
#[cfg(unix)]
Self::Unix(stream) => stream.shutdown().await,
}
}
pub fn take_error(&self) -> std::io::Result<Option<std::io::Error>> {
match self {
Self::Tcp(stream) => stream.take_error(),
#[cfg(unix)]
Self::Unix(stream) => stream.take_error(),
}
}
}
pub struct AsyncPooledTransport {
transport: AsyncTransport,
broken: bool,
}
impl AsyncPooledTransport {
pub fn new(transport: AsyncTransport) -> Self {
Self {
transport,
broken: false,
}
}
pub fn mark_broken(&mut self) {
self.broken = true;
}
pub fn is_broken(&self) -> bool {
self.broken || self.transport.take_error().ok().flatten().is_some()
}
pub fn transport_mut(&mut self) -> &mut AsyncTransport {
&mut self.transport
}
}
/// A reader that transparently decompresses q IPC payloads as they are read.
pub struct DecompressingReader<'a, R> {
reader: &'a mut R,
decompressor: Option<StreamingDecompressor>,
remaining_compressed: usize,
buffer: Vec<u8>,
}
impl<'a, R: AsyncRead + Unpin> DecompressingReader<'a, R> {
pub fn new(
reader: &'a mut R,
decompressor: Option<StreamingDecompressor>,
remaining_compressed: usize,
) -> Self {
Self {
reader,
decompressor,
remaining_compressed,
buffer: vec![0_u8; 8192],
}
}
}
impl<'a, R: AsyncRead + Unpin> AsyncRead for DecompressingReader<'a, R> {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
let this = &mut *self;
if let Some(decompressor) = &mut this.decompressor {
// If we have decompressed data available, yield it first.
if decompressor.unread_len() > 0 {
let chunk = decompressor.next_chunk();
let to_copy = chunk.len().min(buf.remaining());
buf.put_slice(&chunk[..to_copy]);
decompressor.consume(to_copy);
return Poll::Ready(Ok(()));
}
// If decompression is complete and no more unread bytes, EOF.
if decompressor.is_complete() {
return Poll::Ready(Ok(()));
}
// Otherwise, read more compressed data from the underlying reader.
if this.remaining_compressed > 0 {
let want = this.remaining_compressed.min(this.buffer.len());
let mut read_buf = ReadBuf::new(&mut this.buffer[..want]);
match Pin::new(&mut this.reader).poll_read(cx, &mut read_buf) {
Poll::Ready(Ok(())) => {
let read = read_buf.filled().len();
if read == 0 && want > 0 {
return Poll::Ready(Err(std::io::Error::new(
std::io::ErrorKind::UnexpectedEof,
"unexpected EOF reading compressed body",
)));
}
this.remaining_compressed -= read;
decompressor.feed(read_buf.filled()).map_err(|e| {
std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string())
})?;
// Recursive call to yield the newly decompressed bytes.
return self.poll_read(cx, buf);
}
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
Poll::Pending => return Poll::Pending,
}
}
Poll::Ready(Ok(()))
} else {
// Uncompressed path: direct read from underlying reader.
Pin::new(&mut this.reader).poll_read(cx, buf)
}
}
}
impl AsyncRead for AsyncPooledTransport {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
Pin::new(&mut self.transport).poll_read(cx, buf)
}
}
impl AsyncWrite for AsyncPooledTransport {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<std::io::Result<usize>> {
Pin::new(&mut self.transport).poll_write(cx, buf)
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
Pin::new(&mut self.transport).poll_flush(cx)
}
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
Pin::new(&mut self.transport).poll_shutdown(cx)
}
}
fn credentials_bytes(username: Option<&str>, password: Option<&str>) -> Vec<u8> {
let username = username.unwrap_or_default();
let password = password.unwrap_or_default();
let mut bytes = format!("{username}:{password}").into_bytes();
bytes.push(CLIENT_CAPABILITY);
bytes.push(0);
bytes
}
fn timeout_error(context: &str, timeout_ms: u64) -> TransportError {
TransportError::Io(std::io::Error::new(
std::io::ErrorKind::TimedOut,
format!("{context} timed out after {timeout_ms}ms"),
))
}
async fn run_with_timeout<T, F>(
timeout_ms: Option<u64>,
context: &str,
future: F,
) -> TransportResult<T>
where
F: std::future::Future<Output = std::io::Result<T>>,
{
match timeout_ms {
Some(timeout_ms) => tokio::time::timeout(Duration::from_millis(timeout_ms), future)
.await
.map_err(|_| timeout_error(context, timeout_ms))?
.map_err(TransportError::Io),
None => future.await.map_err(TransportError::Io),
}
}
async fn perform_handshake<S>(
stream: &mut S,
username: Option<&str>,
password: Option<&str>,
timeout_ms: Option<u64>,
) -> TransportResult<u8>
where
S: AsyncRead + AsyncWrite + Unpin,
{
run_with_timeout(
timeout_ms,
"q IPC handshake write",
stream.write_all(&credentials_bytes(username, password)),
)
.await?;
run_with_timeout(timeout_ms, "q IPC handshake flush", stream.flush()).await?;
let mut capability = [0_u8; 1];
run_with_timeout(
timeout_ms,
"q IPC handshake read",
stream.read_exact(&mut capability),
)
.await?;
Ok(capability[0])
}
pub async fn connect_tcp_transport(
host: &str,
port: u16,
username: Option<&str>,
password: Option<&str>,
timeout_ms: Option<u64>,
) -> TransportResult<AsyncTransport> {
let mut stream =
run_with_timeout(timeout_ms, "TCP connect", TcpStream::connect((host, port))).await?;
stream.set_nodelay(true)?;
perform_handshake(&mut stream, username, password, timeout_ms).await?;
Ok(AsyncTransport::Tcp(stream))
}
#[cfg(unix)]
pub async fn connect_unix_transport(
path: &str,
username: Option<&str>,
password: Option<&str>,
timeout_ms: Option<u64>,
) -> TransportResult<AsyncTransport> {
let mut stream =
run_with_timeout(timeout_ms, "Unix socket connect", UnixStream::connect(path)).await?;
perform_handshake(&mut stream, username, password, timeout_ms).await?;
Ok(AsyncTransport::Unix(stream))
}
pub async fn read_frame<S>(stream: &mut S) -> TransportResult<Vec<u8>>
where
S: AsyncRead + Unpin,
{
let mut header = [0_u8; HEADER_LEN];
stream.read_exact(&mut header).await?;
let message_length = read_message_length(&header)
.map_err(|error| TransportError::Protocol(error.to_string()))?;
let mut frame = vec![0_u8; message_length];
frame[..HEADER_LEN].copy_from_slice(&header);
stream.read_exact(&mut frame[HEADER_LEN..]).await?;
Ok(frame)
}
pub async fn request_frame_over<S>(stream: &mut S, payload: &[u8]) -> TransportResult<Vec<u8>>
where
S: AsyncRead + AsyncWrite + Unpin,
{
stream.write_all(payload).await?;
stream.flush().await?;
read_frame(stream).await
}
pub async fn begin_streaming_frame_over<S>(
stream: &mut S,
payload: &[u8],
) -> TransportResult<([u8; HEADER_LEN], usize)>
where
S: AsyncRead + AsyncWrite + Unpin,
{
stream.write_all(payload).await?;
stream.flush().await?;
let mut header = [0_u8; HEADER_LEN];
stream.read_exact(&mut header).await?;
let message_length = read_message_length(&header)
.map_err(|error| TransportError::Protocol(error.to_string()))?;
Ok((header, message_length - HEADER_LEN))
}
/// Async variant of [`crate::synchronous::request_frame_streaming_over`].
///
/// Sends a payload and reads the response frame, using streaming decompression
/// when the response is compressed.
pub async fn request_frame_streaming_over<S>(
stream: &mut S,
payload: &[u8],
) -> TransportResult<Vec<u8>>
where
S: AsyncRead + AsyncWrite + Unpin,
{
stream.write_all(payload).await?;
stream.flush().await?;
// Read the 8-byte header.
let mut header_bytes = [0_u8; HEADER_LEN];
stream.read_exact(&mut header_bytes).await?;
let header = MessageHeader::from_bytes(header_bytes)
.map_err(|error| TransportError::Protocol(error.to_string()))?;
let body_len = header.body_len();
if header.compression() == Compression::Uncompressed {
let mut frame = vec![0_u8; header.size()];
frame[..HEADER_LEN].copy_from_slice(&header_bytes);
stream.read_exact(&mut frame[HEADER_LEN..]).await?;
return Ok(frame);
}
// Compressed frame: read the 4-byte size prefix first.
if body_len < 4 {
return Err(TransportError::Protocol(
"compressed body must be at least 4 bytes for size prefix".to_string(),
));
}
let mut size_prefix = [0_u8; 4];
stream.read_exact(&mut size_prefix).await?;
let mut decompressor = StreamingDecompressor::new(size_prefix, header.encoding())
.map_err(|error| TransportError::Protocol(error.to_string()))?;
// Read the remaining compressed body in chunks.
let remaining = body_len - 4;
let mut total_read = 0_usize;
let mut chunk = vec![0_u8; 8192];
while total_read < remaining {
let want = (remaining - total_read).min(chunk.len());
stream.read_exact(&mut chunk[..want]).await?;
decompressor
.feed(&chunk[..want])
.map_err(|error| TransportError::Protocol(error.to_string()))?;
total_read += want;
}
if !decompressor.is_complete() {
return Err(TransportError::Protocol(
"streaming decompression did not complete after reading entire body".to_string(),
));
}
let decompressed = decompressor
.finish()
.map_err(|error| TransportError::Protocol(error.to_string()))?;
// Reconstruct as an uncompressed frame.
let new_size = HEADER_LEN + decompressed.len();
let new_header = qroissant_core::MessageHeader::new(
header.encoding(),
header.message_type(),
Compression::Uncompressed,
new_size,
)
.map_err(|error| TransportError::Protocol(error.to_string()))?;
let mut frame = Vec::with_capacity(new_size);
frame.extend_from_slice(
&new_header
.to_bytes()
.map_err(|error| TransportError::Protocol(error.to_string()))?,
);
frame.extend_from_slice(&decompressed);
Ok(frame)
}
use qroissant_core::pipelined::PipelinedReader;
use qroissant_core::pipelined::decode_value_async;
use qroissant_core::value::Value;
pub async fn request_value_pipelined_over<R: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
conn: &mut R,
payload: &[u8],
) -> TransportResult<Value> {
conn.write_all(payload).await.map_err(TransportError::Io)?;
conn.flush().await.map_err(TransportError::Io)?;
let mut header_bytes = [0_u8; HEADER_LEN];
conn.read_exact(&mut header_bytes)
.await
.map_err(TransportError::Io)?;
let header =
MessageHeader::parse(&header_bytes).map_err(|e| TransportError::Protocol(e.to_string()))?;
let (decompressor, remaining_compressed) = if header.compression() != Compression::Uncompressed
{
let mut size_prefix = [0_u8; 4];
conn.read_exact(&mut size_prefix)
.await
.map_err(TransportError::Io)?;
let decompressor = StreamingDecompressor::new(size_prefix, header.encoding())
.map_err(|e| TransportError::Protocol(e.to_string()))?;
(Some(decompressor), header.body_len() - 4)
} else {
(None, header.body_len())
};
let mut decomp_reader = DecompressingReader::new(conn, decompressor, remaining_compressed);
let mut pipelined_reader = PipelinedReader::new(&mut decomp_reader, header.encoding())
.map_err(|e| TransportError::Protocol(e.to_string()))?;
decode_value_async(&mut pipelined_reader)
.await
.map_err(|e| TransportError::Protocol(e.to_string()))
}