475 lines
15 KiB
Rust
475 lines
15 KiB
Rust
use std::pin::Pin;
|
|
use std::task::Context;
|
|
use std::task::Poll;
|
|
use std::time::Duration;
|
|
|
|
use qroissant_core::Compression;
|
|
use qroissant_core::HEADER_LEN;
|
|
use qroissant_core::MessageHeader;
|
|
use qroissant_core::StreamingDecompressor;
|
|
use qroissant_core::read_message_length;
|
|
use tokio::io::AsyncRead;
|
|
use tokio::io::AsyncReadExt;
|
|
use tokio::io::AsyncWrite;
|
|
use tokio::io::AsyncWriteExt;
|
|
use tokio::io::ReadBuf;
|
|
use tokio::net::TcpStream;
|
|
#[cfg(unix)]
|
|
use tokio::net::UnixStream;
|
|
|
|
use crate::TransportError;
|
|
use crate::TransportResult;
|
|
use crate::synchronous::CLIENT_CAPABILITY;
|
|
|
|
pub enum AsyncTransport {
|
|
Tcp(TcpStream),
|
|
#[cfg(unix)]
|
|
Unix(UnixStream),
|
|
}
|
|
|
|
impl AsyncRead for AsyncTransport {
|
|
fn poll_read(
|
|
mut self: Pin<&mut Self>,
|
|
cx: &mut Context<'_>,
|
|
buf: &mut ReadBuf<'_>,
|
|
) -> Poll<std::io::Result<()>> {
|
|
match &mut *self {
|
|
Self::Tcp(stream) => Pin::new(stream).poll_read(cx, buf),
|
|
#[cfg(unix)]
|
|
Self::Unix(stream) => Pin::new(stream).poll_read(cx, buf),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl AsyncWrite for AsyncTransport {
|
|
fn poll_write(
|
|
mut self: Pin<&mut Self>,
|
|
cx: &mut Context<'_>,
|
|
buf: &[u8],
|
|
) -> Poll<std::io::Result<usize>> {
|
|
match &mut *self {
|
|
Self::Tcp(stream) => Pin::new(stream).poll_write(cx, buf),
|
|
#[cfg(unix)]
|
|
Self::Unix(stream) => Pin::new(stream).poll_write(cx, buf),
|
|
}
|
|
}
|
|
|
|
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
|
|
match &mut *self {
|
|
Self::Tcp(stream) => Pin::new(stream).poll_flush(cx),
|
|
#[cfg(unix)]
|
|
Self::Unix(stream) => Pin::new(stream).poll_flush(cx),
|
|
}
|
|
}
|
|
|
|
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
|
|
match &mut *self {
|
|
Self::Tcp(stream) => Pin::new(stream).poll_shutdown(cx),
|
|
#[cfg(unix)]
|
|
Self::Unix(stream) => Pin::new(stream).poll_shutdown(cx),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl AsyncTransport {
|
|
pub async fn shutdown(&mut self) -> std::io::Result<()> {
|
|
match self {
|
|
Self::Tcp(stream) => stream.shutdown().await,
|
|
#[cfg(unix)]
|
|
Self::Unix(stream) => stream.shutdown().await,
|
|
}
|
|
}
|
|
|
|
pub fn take_error(&self) -> std::io::Result<Option<std::io::Error>> {
|
|
match self {
|
|
Self::Tcp(stream) => stream.take_error(),
|
|
#[cfg(unix)]
|
|
Self::Unix(stream) => stream.take_error(),
|
|
}
|
|
}
|
|
}
|
|
|
|
pub struct AsyncPooledTransport {
|
|
transport: AsyncTransport,
|
|
broken: bool,
|
|
}
|
|
|
|
impl AsyncPooledTransport {
|
|
pub fn new(transport: AsyncTransport) -> Self {
|
|
Self {
|
|
transport,
|
|
broken: false,
|
|
}
|
|
}
|
|
|
|
pub fn mark_broken(&mut self) {
|
|
self.broken = true;
|
|
}
|
|
|
|
pub fn is_broken(&self) -> bool {
|
|
self.broken || self.transport.take_error().ok().flatten().is_some()
|
|
}
|
|
|
|
pub fn transport_mut(&mut self) -> &mut AsyncTransport {
|
|
&mut self.transport
|
|
}
|
|
}
|
|
|
|
/// A reader that transparently decompresses q IPC payloads as they are read.
|
|
pub struct DecompressingReader<'a, R> {
|
|
reader: &'a mut R,
|
|
decompressor: Option<StreamingDecompressor>,
|
|
remaining_compressed: usize,
|
|
buffer: Vec<u8>,
|
|
}
|
|
|
|
impl<'a, R: AsyncRead + Unpin> DecompressingReader<'a, R> {
|
|
pub fn new(
|
|
reader: &'a mut R,
|
|
decompressor: Option<StreamingDecompressor>,
|
|
remaining_compressed: usize,
|
|
) -> Self {
|
|
Self {
|
|
reader,
|
|
decompressor,
|
|
remaining_compressed,
|
|
buffer: vec![0_u8; 8192],
|
|
}
|
|
}
|
|
}
|
|
|
|
impl<'a, R: AsyncRead + Unpin> AsyncRead for DecompressingReader<'a, R> {
|
|
fn poll_read(
|
|
mut self: Pin<&mut Self>,
|
|
cx: &mut Context<'_>,
|
|
buf: &mut ReadBuf<'_>,
|
|
) -> Poll<std::io::Result<()>> {
|
|
let this = &mut *self;
|
|
|
|
if let Some(decompressor) = &mut this.decompressor {
|
|
// If we have decompressed data available, yield it first.
|
|
if decompressor.unread_len() > 0 {
|
|
let chunk = decompressor.next_chunk();
|
|
let to_copy = chunk.len().min(buf.remaining());
|
|
buf.put_slice(&chunk[..to_copy]);
|
|
decompressor.consume(to_copy);
|
|
return Poll::Ready(Ok(()));
|
|
}
|
|
|
|
// If decompression is complete and no more unread bytes, EOF.
|
|
if decompressor.is_complete() {
|
|
return Poll::Ready(Ok(()));
|
|
}
|
|
|
|
// Otherwise, read more compressed data from the underlying reader.
|
|
if this.remaining_compressed > 0 {
|
|
let want = this.remaining_compressed.min(this.buffer.len());
|
|
let mut read_buf = ReadBuf::new(&mut this.buffer[..want]);
|
|
match Pin::new(&mut this.reader).poll_read(cx, &mut read_buf) {
|
|
Poll::Ready(Ok(())) => {
|
|
let read = read_buf.filled().len();
|
|
if read == 0 && want > 0 {
|
|
return Poll::Ready(Err(std::io::Error::new(
|
|
std::io::ErrorKind::UnexpectedEof,
|
|
"unexpected EOF reading compressed body",
|
|
)));
|
|
}
|
|
this.remaining_compressed -= read;
|
|
decompressor.feed(read_buf.filled()).map_err(|e| {
|
|
std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string())
|
|
})?;
|
|
|
|
// Recursive call to yield the newly decompressed bytes.
|
|
return self.poll_read(cx, buf);
|
|
}
|
|
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
|
|
Poll::Pending => return Poll::Pending,
|
|
}
|
|
}
|
|
Poll::Ready(Ok(()))
|
|
} else {
|
|
// Uncompressed path: direct read from underlying reader.
|
|
Pin::new(&mut this.reader).poll_read(cx, buf)
|
|
}
|
|
}
|
|
}
|
|
|
|
impl AsyncRead for AsyncPooledTransport {
|
|
fn poll_read(
|
|
mut self: Pin<&mut Self>,
|
|
cx: &mut Context<'_>,
|
|
buf: &mut ReadBuf<'_>,
|
|
) -> Poll<std::io::Result<()>> {
|
|
Pin::new(&mut self.transport).poll_read(cx, buf)
|
|
}
|
|
}
|
|
|
|
impl AsyncWrite for AsyncPooledTransport {
|
|
fn poll_write(
|
|
mut self: Pin<&mut Self>,
|
|
cx: &mut Context<'_>,
|
|
buf: &[u8],
|
|
) -> Poll<std::io::Result<usize>> {
|
|
Pin::new(&mut self.transport).poll_write(cx, buf)
|
|
}
|
|
|
|
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
|
|
Pin::new(&mut self.transport).poll_flush(cx)
|
|
}
|
|
|
|
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
|
|
Pin::new(&mut self.transport).poll_shutdown(cx)
|
|
}
|
|
}
|
|
|
|
fn credentials_bytes(username: Option<&str>, password: Option<&str>) -> Vec<u8> {
|
|
let username = username.unwrap_or_default();
|
|
let password = password.unwrap_or_default();
|
|
let mut bytes = format!("{username}:{password}").into_bytes();
|
|
bytes.push(CLIENT_CAPABILITY);
|
|
bytes.push(0);
|
|
bytes
|
|
}
|
|
|
|
fn timeout_error(context: &str, timeout_ms: u64) -> TransportError {
|
|
TransportError::Io(std::io::Error::new(
|
|
std::io::ErrorKind::TimedOut,
|
|
format!("{context} timed out after {timeout_ms}ms"),
|
|
))
|
|
}
|
|
|
|
async fn run_with_timeout<T, F>(
|
|
timeout_ms: Option<u64>,
|
|
context: &str,
|
|
future: F,
|
|
) -> TransportResult<T>
|
|
where
|
|
F: std::future::Future<Output = std::io::Result<T>>,
|
|
{
|
|
match timeout_ms {
|
|
Some(timeout_ms) => tokio::time::timeout(Duration::from_millis(timeout_ms), future)
|
|
.await
|
|
.map_err(|_| timeout_error(context, timeout_ms))?
|
|
.map_err(TransportError::Io),
|
|
None => future.await.map_err(TransportError::Io),
|
|
}
|
|
}
|
|
|
|
async fn perform_handshake<S>(
|
|
stream: &mut S,
|
|
username: Option<&str>,
|
|
password: Option<&str>,
|
|
timeout_ms: Option<u64>,
|
|
) -> TransportResult<u8>
|
|
where
|
|
S: AsyncRead + AsyncWrite + Unpin,
|
|
{
|
|
run_with_timeout(
|
|
timeout_ms,
|
|
"q IPC handshake write",
|
|
stream.write_all(&credentials_bytes(username, password)),
|
|
)
|
|
.await?;
|
|
run_with_timeout(timeout_ms, "q IPC handshake flush", stream.flush()).await?;
|
|
|
|
let mut capability = [0_u8; 1];
|
|
run_with_timeout(
|
|
timeout_ms,
|
|
"q IPC handshake read",
|
|
stream.read_exact(&mut capability),
|
|
)
|
|
.await?;
|
|
Ok(capability[0])
|
|
}
|
|
|
|
pub async fn connect_tcp_transport(
|
|
host: &str,
|
|
port: u16,
|
|
username: Option<&str>,
|
|
password: Option<&str>,
|
|
timeout_ms: Option<u64>,
|
|
) -> TransportResult<AsyncTransport> {
|
|
let mut stream =
|
|
run_with_timeout(timeout_ms, "TCP connect", TcpStream::connect((host, port))).await?;
|
|
stream.set_nodelay(true)?;
|
|
perform_handshake(&mut stream, username, password, timeout_ms).await?;
|
|
Ok(AsyncTransport::Tcp(stream))
|
|
}
|
|
|
|
#[cfg(unix)]
|
|
pub async fn connect_unix_transport(
|
|
path: &str,
|
|
username: Option<&str>,
|
|
password: Option<&str>,
|
|
timeout_ms: Option<u64>,
|
|
) -> TransportResult<AsyncTransport> {
|
|
let mut stream =
|
|
run_with_timeout(timeout_ms, "Unix socket connect", UnixStream::connect(path)).await?;
|
|
perform_handshake(&mut stream, username, password, timeout_ms).await?;
|
|
Ok(AsyncTransport::Unix(stream))
|
|
}
|
|
|
|
pub async fn read_frame<S>(stream: &mut S) -> TransportResult<Vec<u8>>
|
|
where
|
|
S: AsyncRead + Unpin,
|
|
{
|
|
let mut header = [0_u8; HEADER_LEN];
|
|
stream.read_exact(&mut header).await?;
|
|
let message_length = read_message_length(&header)
|
|
.map_err(|error| TransportError::Protocol(error.to_string()))?;
|
|
let mut frame = vec![0_u8; message_length];
|
|
frame[..HEADER_LEN].copy_from_slice(&header);
|
|
stream.read_exact(&mut frame[HEADER_LEN..]).await?;
|
|
Ok(frame)
|
|
}
|
|
|
|
pub async fn request_frame_over<S>(stream: &mut S, payload: &[u8]) -> TransportResult<Vec<u8>>
|
|
where
|
|
S: AsyncRead + AsyncWrite + Unpin,
|
|
{
|
|
stream.write_all(payload).await?;
|
|
stream.flush().await?;
|
|
read_frame(stream).await
|
|
}
|
|
|
|
pub async fn begin_streaming_frame_over<S>(
|
|
stream: &mut S,
|
|
payload: &[u8],
|
|
) -> TransportResult<([u8; HEADER_LEN], usize)>
|
|
where
|
|
S: AsyncRead + AsyncWrite + Unpin,
|
|
{
|
|
stream.write_all(payload).await?;
|
|
stream.flush().await?;
|
|
|
|
let mut header = [0_u8; HEADER_LEN];
|
|
stream.read_exact(&mut header).await?;
|
|
let message_length = read_message_length(&header)
|
|
.map_err(|error| TransportError::Protocol(error.to_string()))?;
|
|
Ok((header, message_length - HEADER_LEN))
|
|
}
|
|
|
|
/// Async variant of [`crate::synchronous::request_frame_streaming_over`].
|
|
///
|
|
/// Sends a payload and reads the response frame, using streaming decompression
|
|
/// when the response is compressed.
|
|
pub async fn request_frame_streaming_over<S>(
|
|
stream: &mut S,
|
|
payload: &[u8],
|
|
) -> TransportResult<Vec<u8>>
|
|
where
|
|
S: AsyncRead + AsyncWrite + Unpin,
|
|
{
|
|
stream.write_all(payload).await?;
|
|
stream.flush().await?;
|
|
|
|
// Read the 8-byte header.
|
|
let mut header_bytes = [0_u8; HEADER_LEN];
|
|
stream.read_exact(&mut header_bytes).await?;
|
|
|
|
let header = MessageHeader::from_bytes(header_bytes)
|
|
.map_err(|error| TransportError::Protocol(error.to_string()))?;
|
|
let body_len = header.body_len();
|
|
|
|
if header.compression() == Compression::Uncompressed {
|
|
let mut frame = vec![0_u8; header.size()];
|
|
frame[..HEADER_LEN].copy_from_slice(&header_bytes);
|
|
stream.read_exact(&mut frame[HEADER_LEN..]).await?;
|
|
return Ok(frame);
|
|
}
|
|
|
|
// Compressed frame: read the 4-byte size prefix first.
|
|
if body_len < 4 {
|
|
return Err(TransportError::Protocol(
|
|
"compressed body must be at least 4 bytes for size prefix".to_string(),
|
|
));
|
|
}
|
|
|
|
let mut size_prefix = [0_u8; 4];
|
|
stream.read_exact(&mut size_prefix).await?;
|
|
|
|
let mut decompressor = StreamingDecompressor::new(size_prefix, header.encoding())
|
|
.map_err(|error| TransportError::Protocol(error.to_string()))?;
|
|
|
|
// Read the remaining compressed body in chunks.
|
|
let remaining = body_len - 4;
|
|
let mut total_read = 0_usize;
|
|
let mut chunk = vec![0_u8; 8192];
|
|
|
|
while total_read < remaining {
|
|
let want = (remaining - total_read).min(chunk.len());
|
|
stream.read_exact(&mut chunk[..want]).await?;
|
|
decompressor
|
|
.feed(&chunk[..want])
|
|
.map_err(|error| TransportError::Protocol(error.to_string()))?;
|
|
total_read += want;
|
|
}
|
|
|
|
if !decompressor.is_complete() {
|
|
return Err(TransportError::Protocol(
|
|
"streaming decompression did not complete after reading entire body".to_string(),
|
|
));
|
|
}
|
|
|
|
let decompressed = decompressor
|
|
.finish()
|
|
.map_err(|error| TransportError::Protocol(error.to_string()))?;
|
|
|
|
// Reconstruct as an uncompressed frame.
|
|
let new_size = HEADER_LEN + decompressed.len();
|
|
let new_header = qroissant_core::MessageHeader::new(
|
|
header.encoding(),
|
|
header.message_type(),
|
|
Compression::Uncompressed,
|
|
new_size,
|
|
)
|
|
.map_err(|error| TransportError::Protocol(error.to_string()))?;
|
|
|
|
let mut frame = Vec::with_capacity(new_size);
|
|
frame.extend_from_slice(
|
|
&new_header
|
|
.to_bytes()
|
|
.map_err(|error| TransportError::Protocol(error.to_string()))?,
|
|
);
|
|
frame.extend_from_slice(&decompressed);
|
|
Ok(frame)
|
|
}
|
|
use qroissant_core::pipelined::PipelinedReader;
|
|
use qroissant_core::pipelined::decode_value_async;
|
|
use qroissant_core::value::Value;
|
|
|
|
pub async fn request_value_pipelined_over<R: AsyncRead + AsyncWrite + Unpin + Send + 'static>(
|
|
conn: &mut R,
|
|
payload: &[u8],
|
|
) -> TransportResult<Value> {
|
|
conn.write_all(payload).await.map_err(TransportError::Io)?;
|
|
conn.flush().await.map_err(TransportError::Io)?;
|
|
|
|
let mut header_bytes = [0_u8; HEADER_LEN];
|
|
conn.read_exact(&mut header_bytes)
|
|
.await
|
|
.map_err(TransportError::Io)?;
|
|
let header =
|
|
MessageHeader::parse(&header_bytes).map_err(|e| TransportError::Protocol(e.to_string()))?;
|
|
|
|
let (decompressor, remaining_compressed) = if header.compression() != Compression::Uncompressed
|
|
{
|
|
let mut size_prefix = [0_u8; 4];
|
|
conn.read_exact(&mut size_prefix)
|
|
.await
|
|
.map_err(TransportError::Io)?;
|
|
let decompressor = StreamingDecompressor::new(size_prefix, header.encoding())
|
|
.map_err(|e| TransportError::Protocol(e.to_string()))?;
|
|
(Some(decompressor), header.body_len() - 4)
|
|
} else {
|
|
(None, header.body_len())
|
|
};
|
|
|
|
let mut decomp_reader = DecompressingReader::new(conn, decompressor, remaining_compressed);
|
|
let mut pipelined_reader = PipelinedReader::new(&mut decomp_reader, header.encoding())
|
|
.map_err(|e| TransportError::Protocol(e.to_string()))?;
|
|
|
|
decode_value_async(&mut pipelined_reader)
|
|
.await
|
|
.map_err(|e| TransportError::Protocol(e.to_string()))
|
|
}
|