Vendor qroissant 0.3.0 baseline

This commit is contained in:
Cam Zalewski 2026-05-20 14:11:30 +01:00
commit 53ac90fe84
56 changed files with 18309 additions and 0 deletions

View file

@ -0,0 +1,475 @@
use std::pin::Pin;
use std::task::Context;
use std::task::Poll;
use std::time::Duration;
use qroissant_core::Compression;
use qroissant_core::HEADER_LEN;
use qroissant_core::MessageHeader;
use qroissant_core::StreamingDecompressor;
use qroissant_core::read_message_length;
use tokio::io::AsyncRead;
use tokio::io::AsyncReadExt;
use tokio::io::AsyncWrite;
use tokio::io::AsyncWriteExt;
use tokio::io::ReadBuf;
use tokio::net::TcpStream;
#[cfg(unix)]
use tokio::net::UnixStream;
use crate::TransportError;
use crate::TransportResult;
use crate::synchronous::CLIENT_CAPABILITY;
pub enum AsyncTransport {
Tcp(TcpStream),
#[cfg(unix)]
Unix(UnixStream),
}
impl AsyncRead for AsyncTransport {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
match &mut *self {
Self::Tcp(stream) => Pin::new(stream).poll_read(cx, buf),
#[cfg(unix)]
Self::Unix(stream) => Pin::new(stream).poll_read(cx, buf),
}
}
}
impl AsyncWrite for AsyncTransport {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<std::io::Result<usize>> {
match &mut *self {
Self::Tcp(stream) => Pin::new(stream).poll_write(cx, buf),
#[cfg(unix)]
Self::Unix(stream) => Pin::new(stream).poll_write(cx, buf),
}
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
match &mut *self {
Self::Tcp(stream) => Pin::new(stream).poll_flush(cx),
#[cfg(unix)]
Self::Unix(stream) => Pin::new(stream).poll_flush(cx),
}
}
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
match &mut *self {
Self::Tcp(stream) => Pin::new(stream).poll_shutdown(cx),
#[cfg(unix)]
Self::Unix(stream) => Pin::new(stream).poll_shutdown(cx),
}
}
}
impl AsyncTransport {
pub async fn shutdown(&mut self) -> std::io::Result<()> {
match self {
Self::Tcp(stream) => stream.shutdown().await,
#[cfg(unix)]
Self::Unix(stream) => stream.shutdown().await,
}
}
pub fn take_error(&self) -> std::io::Result<Option<std::io::Error>> {
match self {
Self::Tcp(stream) => stream.take_error(),
#[cfg(unix)]
Self::Unix(stream) => stream.take_error(),
}
}
}
pub struct AsyncPooledTransport {
transport: AsyncTransport,
broken: bool,
}
impl AsyncPooledTransport {
pub fn new(transport: AsyncTransport) -> Self {
Self {
transport,
broken: false,
}
}
pub fn mark_broken(&mut self) {
self.broken = true;
}
pub fn is_broken(&self) -> bool {
self.broken || self.transport.take_error().ok().flatten().is_some()
}
pub fn transport_mut(&mut self) -> &mut AsyncTransport {
&mut self.transport
}
}
/// A reader that transparently decompresses q IPC payloads as they are read.
pub struct DecompressingReader<'a, R> {
reader: &'a mut R,
decompressor: Option<StreamingDecompressor>,
remaining_compressed: usize,
buffer: Vec<u8>,
}
impl<'a, R: AsyncRead + Unpin> DecompressingReader<'a, R> {
pub fn new(
reader: &'a mut R,
decompressor: Option<StreamingDecompressor>,
remaining_compressed: usize,
) -> Self {
Self {
reader,
decompressor,
remaining_compressed,
buffer: vec![0_u8; 8192],
}
}
}
impl<'a, R: AsyncRead + Unpin> AsyncRead for DecompressingReader<'a, R> {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
let this = &mut *self;
if let Some(decompressor) = &mut this.decompressor {
// If we have decompressed data available, yield it first.
if decompressor.unread_len() > 0 {
let chunk = decompressor.next_chunk();
let to_copy = chunk.len().min(buf.remaining());
buf.put_slice(&chunk[..to_copy]);
decompressor.consume(to_copy);
return Poll::Ready(Ok(()));
}
// If decompression is complete and no more unread bytes, EOF.
if decompressor.is_complete() {
return Poll::Ready(Ok(()));
}
// Otherwise, read more compressed data from the underlying reader.
if this.remaining_compressed > 0 {
let want = this.remaining_compressed.min(this.buffer.len());
let mut read_buf = ReadBuf::new(&mut this.buffer[..want]);
match Pin::new(&mut this.reader).poll_read(cx, &mut read_buf) {
Poll::Ready(Ok(())) => {
let read = read_buf.filled().len();
if read == 0 && want > 0 {
return Poll::Ready(Err(std::io::Error::new(
std::io::ErrorKind::UnexpectedEof,
"unexpected EOF reading compressed body",
)));
}
this.remaining_compressed -= read;
decompressor.feed(read_buf.filled()).map_err(|e| {
std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string())
})?;
// Recursive call to yield the newly decompressed bytes.
return self.poll_read(cx, buf);
}
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
Poll::Pending => return Poll::Pending,
}
}
Poll::Ready(Ok(()))
} else {
// Uncompressed path: direct read from underlying reader.
Pin::new(&mut this.reader).poll_read(cx, buf)
}
}
}
impl AsyncRead for AsyncPooledTransport {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
Pin::new(&mut self.transport).poll_read(cx, buf)
}
}
impl AsyncWrite for AsyncPooledTransport {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<std::io::Result<usize>> {
Pin::new(&mut self.transport).poll_write(cx, buf)
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
Pin::new(&mut self.transport).poll_flush(cx)
}
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
Pin::new(&mut self.transport).poll_shutdown(cx)
}
}
fn credentials_bytes(username: Option<&str>, password: Option<&str>) -> Vec<u8> {
let username = username.unwrap_or_default();
let password = password.unwrap_or_default();
let mut bytes = format!("{username}:{password}").into_bytes();
bytes.push(CLIENT_CAPABILITY);
bytes.push(0);
bytes
}
fn timeout_error(context: &str, timeout_ms: u64) -> TransportError {
TransportError::Io(std::io::Error::new(
std::io::ErrorKind::TimedOut,
format!("{context} timed out after {timeout_ms}ms"),
))
}
async fn run_with_timeout<T, F>(
timeout_ms: Option<u64>,
context: &str,
future: F,
) -> TransportResult<T>
where
F: std::future::Future<Output = std::io::Result<T>>,
{
match timeout_ms {
Some(timeout_ms) => tokio::time::timeout(Duration::from_millis(timeout_ms), future)
.await
.map_err(|_| timeout_error(context, timeout_ms))?
.map_err(TransportError::Io),
None => future.await.map_err(TransportError::Io),
}
}
async fn perform_handshake<S>(
stream: &mut S,
username: Option<&str>,
password: Option<&str>,
timeout_ms: Option<u64>,
) -> TransportResult<u8>
where
S: AsyncRead + AsyncWrite + Unpin,
{
run_with_timeout(
timeout_ms,
"q IPC handshake write",
stream.write_all(&credentials_bytes(username, password)),
)
.await?;
run_with_timeout(timeout_ms, "q IPC handshake flush", stream.flush()).await?;
let mut capability = [0_u8; 1];
run_with_timeout(
timeout_ms,
"q IPC handshake read",
stream.read_exact(&mut capability),
)
.await?;
Ok(capability[0])
}
pub async fn connect_tcp_transport(
host: &str,
port: u16,
username: Option<&str>,
password: Option<&str>,
timeout_ms: Option<u64>,
) -> TransportResult<AsyncTransport> {
let mut stream =
run_with_timeout(timeout_ms, "TCP connect", TcpStream::connect((host, port))).await?;
stream.set_nodelay(true)?;
perform_handshake(&mut stream, username, password, timeout_ms).await?;
Ok(AsyncTransport::Tcp(stream))
}
#[cfg(unix)]
pub async fn connect_unix_transport(
path: &str,
username: Option<&str>,
password: Option<&str>,
timeout_ms: Option<u64>,
) -> TransportResult<AsyncTransport> {
let mut stream =
run_with_timeout(timeout_ms, "Unix socket connect", UnixStream::connect(path)).await?;
perform_handshake(&mut stream, username, password, timeout_ms).await?;
Ok(AsyncTransport::Unix(stream))
}
pub async fn read_frame<S>(stream: &mut S) -> TransportResult<Vec<u8>>
where
S: AsyncRead + Unpin,
{
let mut header = [0_u8; HEADER_LEN];
stream.read_exact(&mut header).await?;
let message_length = read_message_length(&header)
.map_err(|error| TransportError::Protocol(error.to_string()))?;
let mut frame = vec![0_u8; message_length];
frame[..HEADER_LEN].copy_from_slice(&header);
stream.read_exact(&mut frame[HEADER_LEN..]).await?;
Ok(frame)
}
pub async fn request_frame_over<S>(stream: &mut S, payload: &[u8]) -> TransportResult<Vec<u8>>
where
S: AsyncRead + AsyncWrite + Unpin,
{
stream.write_all(payload).await?;
stream.flush().await?;
read_frame(stream).await
}
pub async fn begin_streaming_frame_over<S>(
stream: &mut S,
payload: &[u8],
) -> TransportResult<([u8; HEADER_LEN], usize)>
where
S: AsyncRead + AsyncWrite + Unpin,
{
stream.write_all(payload).await?;
stream.flush().await?;
let mut header = [0_u8; HEADER_LEN];
stream.read_exact(&mut header).await?;
let message_length = read_message_length(&header)
.map_err(|error| TransportError::Protocol(error.to_string()))?;
Ok((header, message_length - HEADER_LEN))
}
/// Async variant of [`crate::synchronous::request_frame_streaming_over`].
///
/// Sends a payload and reads the response frame, using streaming decompression
/// when the response is compressed.
pub async fn request_frame_streaming_over<S>(
stream: &mut S,
payload: &[u8],
) -> TransportResult<Vec<u8>>
where
S: AsyncRead + AsyncWrite + Unpin,
{
stream.write_all(payload).await?;
stream.flush().await?;
// Read the 8-byte header.
let mut header_bytes = [0_u8; HEADER_LEN];
stream.read_exact(&mut header_bytes).await?;
let header = MessageHeader::from_bytes(header_bytes)
.map_err(|error| TransportError::Protocol(error.to_string()))?;
let body_len = header.body_len();
if header.compression() == Compression::Uncompressed {
let mut frame = vec![0_u8; header.size()];
frame[..HEADER_LEN].copy_from_slice(&header_bytes);
stream.read_exact(&mut frame[HEADER_LEN..]).await?;
return Ok(frame);
}
// Compressed frame: read the 4-byte size prefix first.
if body_len < 4 {
return Err(TransportError::Protocol(
"compressed body must be at least 4 bytes for size prefix".to_string(),
));
}
let mut size_prefix = [0_u8; 4];
stream.read_exact(&mut size_prefix).await?;
let mut decompressor = StreamingDecompressor::new(size_prefix, header.encoding())
.map_err(|error| TransportError::Protocol(error.to_string()))?;
// Read the remaining compressed body in chunks.
let remaining = body_len - 4;
let mut total_read = 0_usize;
let mut chunk = vec![0_u8; 8192];
while total_read < remaining {
let want = (remaining - total_read).min(chunk.len());
stream.read_exact(&mut chunk[..want]).await?;
decompressor
.feed(&chunk[..want])
.map_err(|error| TransportError::Protocol(error.to_string()))?;
total_read += want;
}
if !decompressor.is_complete() {
return Err(TransportError::Protocol(
"streaming decompression did not complete after reading entire body".to_string(),
));
}
let decompressed = decompressor
.finish()
.map_err(|error| TransportError::Protocol(error.to_string()))?;
// Reconstruct as an uncompressed frame.
let new_size = HEADER_LEN + decompressed.len();
let new_header = qroissant_core::MessageHeader::new(
header.encoding(),
header.message_type(),
Compression::Uncompressed,
new_size,
)
.map_err(|error| TransportError::Protocol(error.to_string()))?;
let mut frame = Vec::with_capacity(new_size);
frame.extend_from_slice(
&new_header
.to_bytes()
.map_err(|error| TransportError::Protocol(error.to_string()))?,
);
frame.extend_from_slice(&decompressed);
Ok(frame)
}
use qroissant_core::pipelined::PipelinedReader;
use qroissant_core::pipelined::decode_value_async;
use qroissant_core::value::Value;
pub async fn request_value_pipelined_over<R: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
conn: &mut R,
payload: &[u8],
) -> TransportResult<Value> {
conn.write_all(payload).await.map_err(TransportError::Io)?;
conn.flush().await.map_err(TransportError::Io)?;
let mut header_bytes = [0_u8; HEADER_LEN];
conn.read_exact(&mut header_bytes)
.await
.map_err(TransportError::Io)?;
let header =
MessageHeader::parse(&header_bytes).map_err(|e| TransportError::Protocol(e.to_string()))?;
let (decompressor, remaining_compressed) = if header.compression() != Compression::Uncompressed
{
let mut size_prefix = [0_u8; 4];
conn.read_exact(&mut size_prefix)
.await
.map_err(TransportError::Io)?;
let decompressor = StreamingDecompressor::new(size_prefix, header.encoding())
.map_err(|e| TransportError::Protocol(e.to_string()))?;
(Some(decompressor), header.body_len() - 4)
} else {
(None, header.body_len())
};
let mut decomp_reader = DecompressingReader::new(conn, decompressor, remaining_compressed);
let mut pipelined_reader = PipelinedReader::new(&mut decomp_reader, header.encoding())
.map_err(|e| TransportError::Protocol(e.to_string()))?;
decode_value_async(&mut pipelined_reader)
.await
.map_err(|e| TransportError::Protocol(e.to_string()))
}

View file

@ -0,0 +1,42 @@
use std::fmt;
pub type TransportResult<T> = Result<T, TransportError>;
#[derive(Debug)]
pub enum TransportError {
Io(std::io::Error),
InvalidEndpoint(String),
InvalidQueryLength(usize),
Protocol(String),
Closed,
}
impl fmt::Display for TransportError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Io(error) => error.fmt(f),
Self::InvalidEndpoint(message) => write!(f, "{message}"),
Self::InvalidQueryLength(length) => write!(
f,
"q query string length {length} exceeds 32-bit q IPC capacity"
),
Self::Protocol(message) => write!(f, "{message}"),
Self::Closed => write!(f, "connection is closed"),
}
}
}
impl std::error::Error for TransportError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
Self::Io(error) => Some(error),
_ => None,
}
}
}
impl From<std::io::Error> for TransportError {
fn from(value: std::io::Error) -> Self {
Self::Io(value)
}
}

View file

@ -0,0 +1,37 @@
//! Shared q IPC transport primitives.
mod asynchronous;
mod error;
mod synchronous;
pub use asynchronous::AsyncPooledTransport;
pub use asynchronous::AsyncTransport;
pub use asynchronous::begin_streaming_frame_over as begin_streaming_frame_over_async;
pub use asynchronous::connect_tcp_transport as connect_tcp_transport_async;
#[cfg(unix)]
pub use asynchronous::connect_unix_transport as connect_unix_transport_async;
pub use asynchronous::read_frame as read_frame_async;
pub use asynchronous::request_frame_over as request_frame_over_async;
pub use asynchronous::request_frame_streaming_over as request_frame_streaming_over_async;
pub use asynchronous::request_value_pipelined_over as request_value_pipelined_over_async;
pub use error::TransportError;
pub use error::TransportResult;
pub use qroissant_core::HEADER_LEN as QIPC_HEADER_LEN;
pub use synchronous::CLIENT_CAPABILITY;
pub use synchronous::SyncConnection;
pub use synchronous::SyncPooledTransport;
pub use synchronous::SyncTransport;
pub use synchronous::begin_streaming_frame_over;
pub use synchronous::connect_tcp_transport;
#[cfg(unix)]
pub use synchronous::connect_unix_transport;
pub use synchronous::credentials_bytes;
pub use synchronous::encode_sync_query;
pub use synchronous::extract_q_error;
pub use synchronous::parse_message_header;
pub use synchronous::perform_handshake;
pub use synchronous::request_frame_over;
pub use synchronous::request_frame_streaming_over;
pub use synchronous::validate_response_frame;
pub use synchronous::validate_response_header;
pub use synchronous::validate_response_header_bytes;

View file

@ -0,0 +1,420 @@
use std::io::Read;
use std::io::Write;
use std::net::Shutdown;
use std::net::TcpStream;
#[cfg(unix)]
use std::os::unix::net::UnixStream;
use std::time::Duration;
use qroissant_core::Attribute;
use qroissant_core::Compression;
use qroissant_core::Encoding;
use qroissant_core::Frame;
use qroissant_core::HEADER_LEN;
use qroissant_core::MessageHeader;
use qroissant_core::MessageType;
use qroissant_core::StreamingDecompressor;
use qroissant_core::Value;
use qroissant_core::Vector;
use qroissant_core::VectorData;
use qroissant_core::encode_message;
use qroissant_core::read_frame;
use qroissant_core::read_message_length;
use crate::TransportError;
use crate::TransportResult;
pub const CLIENT_CAPABILITY: u8 = 3;
pub enum SyncTransport {
Tcp(TcpStream),
#[cfg(unix)]
Unix(UnixStream),
}
impl SyncTransport {
pub fn shutdown(&mut self) -> std::io::Result<()> {
match self {
Self::Tcp(stream) => stream.shutdown(Shutdown::Both),
#[cfg(unix)]
Self::Unix(stream) => stream.shutdown(Shutdown::Both),
}
}
pub fn take_error(&self) -> std::io::Result<Option<std::io::Error>> {
match self {
Self::Tcp(stream) => stream.take_error(),
#[cfg(unix)]
Self::Unix(stream) => stream.take_error(),
}
}
pub fn set_timeouts(&self, timeout_ms: Option<u64>) -> std::io::Result<()> {
let timeout = timeout_ms.map(Duration::from_millis);
match self {
Self::Tcp(stream) => {
stream.set_read_timeout(timeout)?;
stream.set_write_timeout(timeout)?;
stream.set_nodelay(true)
}
#[cfg(unix)]
Self::Unix(stream) => {
stream.set_read_timeout(timeout)?;
stream.set_write_timeout(timeout)
}
}
}
}
impl Read for SyncTransport {
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
match self {
Self::Tcp(stream) => stream.read(buf),
#[cfg(unix)]
Self::Unix(stream) => stream.read(buf),
}
}
}
impl Write for SyncTransport {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
match self {
Self::Tcp(stream) => stream.write(buf),
#[cfg(unix)]
Self::Unix(stream) => stream.write(buf),
}
}
fn flush(&mut self) -> std::io::Result<()> {
match self {
Self::Tcp(stream) => stream.flush(),
#[cfg(unix)]
Self::Unix(stream) => stream.flush(),
}
}
}
pub struct SyncPooledTransport {
transport: SyncTransport,
broken: bool,
}
impl SyncPooledTransport {
pub fn new(transport: SyncTransport) -> Self {
Self {
transport,
broken: false,
}
}
pub fn mark_broken(&mut self) {
self.broken = true;
}
pub fn is_broken(&self) -> bool {
self.broken || self.transport.take_error().ok().flatten().is_some()
}
pub fn shutdown(&mut self) -> std::io::Result<()> {
self.transport.shutdown()
}
}
impl Read for SyncPooledTransport {
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
self.transport.read(buf)
}
}
impl Write for SyncPooledTransport {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
self.transport.write(buf)
}
fn flush(&mut self) -> std::io::Result<()> {
self.transport.flush()
}
}
pub fn credentials_bytes(username: Option<&str>, password: Option<&str>) -> Vec<u8> {
let username = username.unwrap_or_default();
let password = password.unwrap_or_default();
let mut bytes = format!("{username}:{password}").into_bytes();
bytes.push(CLIENT_CAPABILITY);
bytes.push(0);
bytes
}
pub fn perform_handshake<S: Read + Write>(
stream: &mut S,
username: Option<&str>,
password: Option<&str>,
) -> TransportResult<u8> {
stream.write_all(&credentials_bytes(username, password))?;
stream.flush()?;
let mut capability = [0_u8; 1];
stream.read_exact(&mut capability)?;
Ok(capability[0])
}
pub fn encode_sync_query(message: &str) -> TransportResult<Vec<u8>> {
let _ = i32::try_from(message.len())
.map_err(|_| TransportError::InvalidQueryLength(message.len()))?;
let value = Value::Vector(Vector::new(
Attribute::None,
VectorData::Char(bytes::Bytes::copy_from_slice(message.as_bytes())),
));
encode_message(
&value,
Encoding::LittleEndian,
MessageType::Synchronous,
Compression::Uncompressed,
)
.map_err(|error| TransportError::Protocol(error.to_string()))
}
pub fn extract_q_error(frame_bytes: &[u8]) -> TransportResult<Option<String>> {
let frame =
Frame::parse(frame_bytes).map_err(|error| TransportError::Protocol(error.to_string()))?;
let body = frame.body();
if body.first().copied() != Some(128) {
return Ok(None);
}
let message = match body[1..].iter().position(|byte| *byte == 0) {
Some(end) => &body[1..1 + end],
None => &body[1..],
};
Ok(Some(String::from_utf8_lossy(message).into_owned()))
}
pub fn parse_message_header(header_bytes: [u8; HEADER_LEN]) -> TransportResult<MessageHeader> {
MessageHeader::from_bytes(header_bytes)
.map_err(|error| TransportError::Protocol(error.to_string()))
}
pub fn validate_response_header(header: MessageHeader) -> TransportResult<()> {
if header.message_type() != MessageType::Response {
return Err(TransportError::Protocol(format!(
"expected a q response frame, received {:?}",
header.message_type()
)));
}
Ok(())
}
pub fn validate_response_header_bytes(
header_bytes: [u8; HEADER_LEN],
) -> TransportResult<MessageHeader> {
let header = parse_message_header(header_bytes)?;
validate_response_header(header)?;
Ok(header)
}
pub fn validate_response_frame(frame_bytes: &[u8]) -> TransportResult<MessageHeader> {
let frame =
Frame::parse(frame_bytes).map_err(|error| TransportError::Protocol(error.to_string()))?;
let header = frame.header();
validate_response_header(header)?;
Ok(header)
}
pub fn connect_tcp_transport(
host: &str,
port: u16,
username: Option<&str>,
password: Option<&str>,
timeout_ms: Option<u64>,
) -> TransportResult<SyncTransport> {
let mut stream = SyncTransport::Tcp(TcpStream::connect((host, port))?);
stream.set_timeouts(timeout_ms)?;
perform_handshake(&mut stream, username, password)?;
Ok(stream)
}
#[cfg(unix)]
pub fn connect_unix_transport(
path: &str,
username: Option<&str>,
password: Option<&str>,
timeout_ms: Option<u64>,
) -> TransportResult<SyncTransport> {
let mut stream = SyncTransport::Unix(UnixStream::connect(path)?);
stream.set_timeouts(timeout_ms)?;
perform_handshake(&mut stream, username, password)?;
Ok(stream)
}
pub fn request_frame_over<S: Read + Write>(
stream: &mut S,
payload: &[u8],
) -> TransportResult<Vec<u8>> {
stream.write_all(payload)?;
stream.flush()?;
read_frame(stream).map_err(|error| TransportError::Protocol(error.to_string()))
}
/// Sends a payload and reads the response frame, using streaming decompression
/// when the response is compressed.
///
/// For compressed frames, the body is read in chunks and fed to a
/// [`StreamingDecompressor`] incrementally, overlapping network I/O with
/// decompression work. The returned frame is reconstructed as an
/// *uncompressed* frame so callers can decode it normally.
///
/// For uncompressed frames, this behaves identically to [`request_frame_over`].
pub fn request_frame_streaming_over<S: Read + Write>(
stream: &mut S,
payload: &[u8],
) -> TransportResult<Vec<u8>> {
stream.write_all(payload)?;
stream.flush()?;
// Read the 8-byte header.
let mut header_bytes = [0_u8; HEADER_LEN];
stream.read_exact(&mut header_bytes)?;
let header = parse_message_header(header_bytes)?;
let body_len = header.body_len();
if header.compression() == Compression::Uncompressed {
// Fast path: read entire uncompressed body.
let mut frame = vec![0_u8; header.size()];
frame[..HEADER_LEN].copy_from_slice(&header_bytes);
stream.read_exact(&mut frame[HEADER_LEN..])?;
return Ok(frame);
}
// Compressed frame: read the 4-byte size prefix first.
if body_len < 4 {
return Err(TransportError::Protocol(
"compressed body must be at least 4 bytes for size prefix".to_string(),
));
}
let mut size_prefix = [0_u8; 4];
stream.read_exact(&mut size_prefix)?;
let mut decompressor = StreamingDecompressor::new(size_prefix, header.encoding())
.map_err(|error| TransportError::Protocol(error.to_string()))?;
// Read the remaining compressed body in chunks.
let remaining = body_len - 4;
let mut total_read = 0_usize;
let mut chunk = [0_u8; 8192];
while total_read < remaining {
let want = (remaining - total_read).min(chunk.len());
stream.read_exact(&mut chunk[..want])?;
decompressor
.feed(&chunk[..want])
.map_err(|error| TransportError::Protocol(error.to_string()))?;
total_read += want;
}
if !decompressor.is_complete() {
return Err(TransportError::Protocol(
"streaming decompression did not complete after reading entire body".to_string(),
));
}
let decompressed = decompressor
.finish()
.map_err(|error| TransportError::Protocol(error.to_string()))?;
// Reconstruct as an uncompressed frame: header + decompressed body.
let new_size = HEADER_LEN + decompressed.len();
let new_header = MessageHeader::new(
header.encoding(),
header.message_type(),
Compression::Uncompressed,
new_size,
)
.map_err(|error| TransportError::Protocol(error.to_string()))?;
let mut frame = Vec::with_capacity(new_size);
frame.extend_from_slice(
&new_header
.to_bytes()
.map_err(|error| TransportError::Protocol(error.to_string()))?,
);
frame.extend_from_slice(&decompressed);
Ok(frame)
}
pub fn begin_streaming_frame_over<S: Read + Write>(
stream: &mut S,
payload: &[u8],
) -> TransportResult<([u8; HEADER_LEN], usize)> {
stream.write_all(payload)?;
stream.flush()?;
let mut header = [0_u8; HEADER_LEN];
stream.read_exact(&mut header)?;
let message_length = read_message_length(&header)
.map_err(|error| TransportError::Protocol(error.to_string()))?;
Ok((header, message_length - HEADER_LEN))
}
pub struct SyncConnection {
transport: Option<SyncTransport>,
}
impl SyncConnection {
pub fn connect_tcp(
host: &str,
port: u16,
username: Option<&str>,
password: Option<&str>,
timeout_ms: Option<u64>,
) -> TransportResult<Self> {
Ok(Self {
transport: Some(connect_tcp_transport(
host, port, username, password, timeout_ms,
)?),
})
}
#[cfg(unix)]
pub fn connect_unix(
path: &str,
username: Option<&str>,
password: Option<&str>,
timeout_ms: Option<u64>,
) -> TransportResult<Self> {
Ok(Self {
transport: Some(connect_unix_transport(
path, username, password, timeout_ms,
)?),
})
}
pub fn query_frame(&mut self, message: &str) -> TransportResult<Vec<u8>> {
let payload = encode_sync_query(message)?;
let transport = self.transport.as_mut().ok_or(TransportError::Closed)?;
let frame = request_frame_over(transport, &payload)?;
validate_response_frame(&frame)?;
Ok(frame)
}
pub fn is_closed(&self) -> bool {
self.transport.is_none()
}
pub fn close(&mut self) -> TransportResult<()> {
let Some(mut transport) = self.transport.take() else {
return Ok(());
};
transport.shutdown()?;
Ok(())
}
}
impl Drop for SyncConnection {
fn drop(&mut self) {
let _ = self.close();
}
}