use std::pin::Pin; use std::task::Context; use std::task::Poll; use std::time::Duration; use qroissant_core::Compression; use qroissant_core::HEADER_LEN; use qroissant_core::MessageHeader; use qroissant_core::StreamingDecompressor; use qroissant_core::read_message_length; use tokio::io::AsyncRead; use tokio::io::AsyncReadExt; use tokio::io::AsyncWrite; use tokio::io::AsyncWriteExt; use tokio::io::ReadBuf; use tokio::net::TcpStream; #[cfg(unix)] use tokio::net::UnixStream; use crate::TransportError; use crate::TransportResult; use crate::synchronous::CLIENT_CAPABILITY; pub enum AsyncTransport { Tcp(TcpStream), #[cfg(unix)] Unix(UnixStream), } impl AsyncRead for AsyncTransport { fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { match &mut *self { Self::Tcp(stream) => Pin::new(stream).poll_read(cx, buf), #[cfg(unix)] Self::Unix(stream) => Pin::new(stream).poll_read(cx, buf), } } } impl AsyncWrite for AsyncTransport { fn poll_write( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { match &mut *self { Self::Tcp(stream) => Pin::new(stream).poll_write(cx, buf), #[cfg(unix)] Self::Unix(stream) => Pin::new(stream).poll_write(cx, buf), } } fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match &mut *self { Self::Tcp(stream) => Pin::new(stream).poll_flush(cx), #[cfg(unix)] Self::Unix(stream) => Pin::new(stream).poll_flush(cx), } } fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match &mut *self { Self::Tcp(stream) => Pin::new(stream).poll_shutdown(cx), #[cfg(unix)] Self::Unix(stream) => Pin::new(stream).poll_shutdown(cx), } } } impl AsyncTransport { pub async fn shutdown(&mut self) -> std::io::Result<()> { match self { Self::Tcp(stream) => stream.shutdown().await, #[cfg(unix)] Self::Unix(stream) => stream.shutdown().await, } } pub fn take_error(&self) -> std::io::Result> { match self { Self::Tcp(stream) => stream.take_error(), #[cfg(unix)] Self::Unix(stream) => stream.take_error(), } } } pub struct AsyncPooledTransport { transport: AsyncTransport, broken: bool, } impl AsyncPooledTransport { pub fn new(transport: AsyncTransport) -> Self { Self { transport, broken: false, } } pub fn mark_broken(&mut self) { self.broken = true; } pub fn is_broken(&self) -> bool { self.broken || self.transport.take_error().ok().flatten().is_some() } pub fn transport_mut(&mut self) -> &mut AsyncTransport { &mut self.transport } } /// A reader that transparently decompresses q IPC payloads as they are read. pub struct DecompressingReader<'a, R> { reader: &'a mut R, decompressor: Option, remaining_compressed: usize, buffer: Vec, } impl<'a, R: AsyncRead + Unpin> DecompressingReader<'a, R> { pub fn new( reader: &'a mut R, decompressor: Option, remaining_compressed: usize, ) -> Self { Self { reader, decompressor, remaining_compressed, buffer: vec![0_u8; 8192], } } } impl<'a, R: AsyncRead + Unpin> AsyncRead for DecompressingReader<'a, R> { fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { let this = &mut *self; if let Some(decompressor) = &mut this.decompressor { // If we have decompressed data available, yield it first. if decompressor.unread_len() > 0 { let chunk = decompressor.next_chunk(); let to_copy = chunk.len().min(buf.remaining()); buf.put_slice(&chunk[..to_copy]); decompressor.consume(to_copy); return Poll::Ready(Ok(())); } // If decompression is complete and no more unread bytes, EOF. if decompressor.is_complete() { return Poll::Ready(Ok(())); } // Otherwise, read more compressed data from the underlying reader. if this.remaining_compressed > 0 { let want = this.remaining_compressed.min(this.buffer.len()); let mut read_buf = ReadBuf::new(&mut this.buffer[..want]); match Pin::new(&mut this.reader).poll_read(cx, &mut read_buf) { Poll::Ready(Ok(())) => { let read = read_buf.filled().len(); if read == 0 && want > 0 { return Poll::Ready(Err(std::io::Error::new( std::io::ErrorKind::UnexpectedEof, "unexpected EOF reading compressed body", ))); } this.remaining_compressed -= read; decompressor.feed(read_buf.filled()).map_err(|e| { std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string()) })?; // Recursive call to yield the newly decompressed bytes. return self.poll_read(cx, buf); } Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), Poll::Pending => return Poll::Pending, } } Poll::Ready(Ok(())) } else { // Uncompressed path: direct read from underlying reader. Pin::new(&mut this.reader).poll_read(cx, buf) } } } impl AsyncRead for AsyncPooledTransport { fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { Pin::new(&mut self.transport).poll_read(cx, buf) } } impl AsyncWrite for AsyncPooledTransport { fn poll_write( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { Pin::new(&mut self.transport).poll_write(cx, buf) } fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { Pin::new(&mut self.transport).poll_flush(cx) } fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { Pin::new(&mut self.transport).poll_shutdown(cx) } } fn credentials_bytes(username: Option<&str>, password: Option<&str>) -> Vec { let username = username.unwrap_or_default(); let password = password.unwrap_or_default(); let mut bytes = format!("{username}:{password}").into_bytes(); bytes.push(CLIENT_CAPABILITY); bytes.push(0); bytes } fn timeout_error(context: &str, timeout_ms: u64) -> TransportError { TransportError::Io(std::io::Error::new( std::io::ErrorKind::TimedOut, format!("{context} timed out after {timeout_ms}ms"), )) } async fn run_with_timeout( timeout_ms: Option, context: &str, future: F, ) -> TransportResult where F: std::future::Future>, { match timeout_ms { Some(timeout_ms) => tokio::time::timeout(Duration::from_millis(timeout_ms), future) .await .map_err(|_| timeout_error(context, timeout_ms))? .map_err(TransportError::Io), None => future.await.map_err(TransportError::Io), } } async fn perform_handshake( stream: &mut S, username: Option<&str>, password: Option<&str>, timeout_ms: Option, ) -> TransportResult where S: AsyncRead + AsyncWrite + Unpin, { run_with_timeout( timeout_ms, "q IPC handshake write", stream.write_all(&credentials_bytes(username, password)), ) .await?; run_with_timeout(timeout_ms, "q IPC handshake flush", stream.flush()).await?; let mut capability = [0_u8; 1]; run_with_timeout( timeout_ms, "q IPC handshake read", stream.read_exact(&mut capability), ) .await?; Ok(capability[0]) } pub async fn connect_tcp_transport( host: &str, port: u16, username: Option<&str>, password: Option<&str>, timeout_ms: Option, ) -> TransportResult { let mut stream = run_with_timeout(timeout_ms, "TCP connect", TcpStream::connect((host, port))).await?; stream.set_nodelay(true)?; perform_handshake(&mut stream, username, password, timeout_ms).await?; Ok(AsyncTransport::Tcp(stream)) } #[cfg(unix)] pub async fn connect_unix_transport( path: &str, username: Option<&str>, password: Option<&str>, timeout_ms: Option, ) -> TransportResult { let mut stream = run_with_timeout(timeout_ms, "Unix socket connect", UnixStream::connect(path)).await?; perform_handshake(&mut stream, username, password, timeout_ms).await?; Ok(AsyncTransport::Unix(stream)) } pub async fn read_frame(stream: &mut S) -> TransportResult> where S: AsyncRead + Unpin, { let mut header = [0_u8; HEADER_LEN]; stream.read_exact(&mut header).await?; let message_length = read_message_length(&header) .map_err(|error| TransportError::Protocol(error.to_string()))?; let mut frame = vec![0_u8; message_length]; frame[..HEADER_LEN].copy_from_slice(&header); stream.read_exact(&mut frame[HEADER_LEN..]).await?; Ok(frame) } pub async fn request_frame_over(stream: &mut S, payload: &[u8]) -> TransportResult> where S: AsyncRead + AsyncWrite + Unpin, { stream.write_all(payload).await?; stream.flush().await?; read_frame(stream).await } pub async fn begin_streaming_frame_over( stream: &mut S, payload: &[u8], ) -> TransportResult<([u8; HEADER_LEN], usize)> where S: AsyncRead + AsyncWrite + Unpin, { stream.write_all(payload).await?; stream.flush().await?; let mut header = [0_u8; HEADER_LEN]; stream.read_exact(&mut header).await?; let message_length = read_message_length(&header) .map_err(|error| TransportError::Protocol(error.to_string()))?; Ok((header, message_length - HEADER_LEN)) } /// Async variant of [`crate::synchronous::request_frame_streaming_over`]. /// /// Sends a payload and reads the response frame, using streaming decompression /// when the response is compressed. pub async fn request_frame_streaming_over( stream: &mut S, payload: &[u8], ) -> TransportResult> where S: AsyncRead + AsyncWrite + Unpin, { stream.write_all(payload).await?; stream.flush().await?; // Read the 8-byte header. let mut header_bytes = [0_u8; HEADER_LEN]; stream.read_exact(&mut header_bytes).await?; let header = MessageHeader::from_bytes(header_bytes) .map_err(|error| TransportError::Protocol(error.to_string()))?; let body_len = header.body_len(); if header.compression() == Compression::Uncompressed { let mut frame = vec![0_u8; header.size()]; frame[..HEADER_LEN].copy_from_slice(&header_bytes); stream.read_exact(&mut frame[HEADER_LEN..]).await?; return Ok(frame); } // Compressed frame: read the 4-byte size prefix first. if body_len < 4 { return Err(TransportError::Protocol( "compressed body must be at least 4 bytes for size prefix".to_string(), )); } let mut size_prefix = [0_u8; 4]; stream.read_exact(&mut size_prefix).await?; let mut decompressor = StreamingDecompressor::new(size_prefix, header.encoding()) .map_err(|error| TransportError::Protocol(error.to_string()))?; // Read the remaining compressed body in chunks. let remaining = body_len - 4; let mut total_read = 0_usize; let mut chunk = vec![0_u8; 8192]; while total_read < remaining { let want = (remaining - total_read).min(chunk.len()); stream.read_exact(&mut chunk[..want]).await?; decompressor .feed(&chunk[..want]) .map_err(|error| TransportError::Protocol(error.to_string()))?; total_read += want; } if !decompressor.is_complete() { return Err(TransportError::Protocol( "streaming decompression did not complete after reading entire body".to_string(), )); } let decompressed = decompressor .finish() .map_err(|error| TransportError::Protocol(error.to_string()))?; // Reconstruct as an uncompressed frame. let new_size = HEADER_LEN + decompressed.len(); let new_header = qroissant_core::MessageHeader::new( header.encoding(), header.message_type(), Compression::Uncompressed, new_size, ) .map_err(|error| TransportError::Protocol(error.to_string()))?; let mut frame = Vec::with_capacity(new_size); frame.extend_from_slice( &new_header .to_bytes() .map_err(|error| TransportError::Protocol(error.to_string()))?, ); frame.extend_from_slice(&decompressed); Ok(frame) } use qroissant_core::pipelined::PipelinedReader; use qroissant_core::pipelined::decode_value_async; use qroissant_core::value::Value; pub async fn request_value_pipelined_over( conn: &mut R, payload: &[u8], ) -> TransportResult { conn.write_all(payload).await.map_err(TransportError::Io)?; conn.flush().await.map_err(TransportError::Io)?; let mut header_bytes = [0_u8; HEADER_LEN]; conn.read_exact(&mut header_bytes) .await .map_err(TransportError::Io)?; let header = MessageHeader::parse(&header_bytes).map_err(|e| TransportError::Protocol(e.to_string()))?; let (decompressor, remaining_compressed) = if header.compression() != Compression::Uncompressed { let mut size_prefix = [0_u8; 4]; conn.read_exact(&mut size_prefix) .await .map_err(TransportError::Io)?; let decompressor = StreamingDecompressor::new(size_prefix, header.encoding()) .map_err(|e| TransportError::Protocol(e.to_string()))?; (Some(decompressor), header.body_len() - 4) } else { (None, header.body_len()) }; let mut decomp_reader = DecompressingReader::new(conn, decompressor, remaining_compressed); let mut pipelined_reader = PipelinedReader::new(&mut decomp_reader, header.encoding()) .map_err(|e| TransportError::Protocol(e.to_string()))?; decode_value_async(&mut pipelined_reader) .await .map_err(|e| TransportError::Protocol(e.to_string())) }