From 53ac90fe84c77951e2084580413957ebc8143841 Mon Sep 17 00:00:00 2001 From: CamZalewski Date: Wed, 20 May 2026 14:11:30 +0100 Subject: [PATCH] Vendor qroissant 0.3.0 baseline --- .gitignore | 31 + Cargo.lock | 1403 +++++++++++++++ Cargo.toml | 26 + PKG-INFO | 246 +++ README.md | 226 +++ crates/qroissant-arrow/Cargo.toml | 24 + crates/qroissant-arrow/src/error.rs | 23 + crates/qroissant-arrow/src/ingestion.rs | 1576 ++++++++++++++++ crates/qroissant-arrow/src/lib.rs | 26 + crates/qroissant-arrow/src/metadata.rs | 90 + crates/qroissant-arrow/src/options.rs | 85 + crates/qroissant-arrow/src/projection.rs | 1160 ++++++++++++ crates/qroissant-core/Cargo.toml | 19 + crates/qroissant-core/src/decode.rs | 907 ++++++++++ crates/qroissant-core/src/encode.rs | 385 ++++ crates/qroissant-core/src/error.rs | 112 ++ crates/qroissant-core/src/extent.rs | 518 ++++++ crates/qroissant-core/src/frame.rs | 826 +++++++++ crates/qroissant-core/src/lib.rs | 61 + crates/qroissant-core/src/pipelined.rs | 390 ++++ crates/qroissant-core/src/protocol.rs | 373 ++++ crates/qroissant-core/src/value.rs | 479 +++++ crates/qroissant-kernels/Cargo.toml | 11 + crates/qroissant-kernels/src/boolean.rs | 121 ++ crates/qroissant-kernels/src/lib.rs | 25 + crates/qroissant-kernels/src/nulls.rs | 371 ++++ crates/qroissant-kernels/src/temporal.rs | 317 ++++ crates/qroissant-python/Cargo.toml | 27 + crates/qroissant-python/src/client.rs | 1597 +++++++++++++++++ crates/qroissant-python/src/errors.rs | 114 ++ crates/qroissant-python/src/lib.rs | 28 + crates/qroissant-python/src/raw_response.rs | 777 ++++++++ crates/qroissant-python/src/repr/cell.rs | 437 +++++ crates/qroissant-python/src/repr/format.rs | 278 +++ crates/qroissant-python/src/repr/mod.rs | 26 + crates/qroissant-python/src/repr/options.rs | 172 ++ crates/qroissant-python/src/repr/render.rs | 80 + crates/qroissant-python/src/serde.rs | 215 +++ crates/qroissant-python/src/types.rs | 1325 ++++++++++++++ crates/qroissant-python/src/values.rs | 925 
++++++++++ crates/qroissant-transport/Cargo.toml | 16 + .../qroissant-transport/src/asynchronous.rs | 475 +++++ crates/qroissant-transport/src/error.rs | 42 + crates/qroissant-transport/src/lib.rs | 37 + crates/qroissant-transport/src/synchronous.rs | 420 +++++ pyproject.toml | 59 + python/qroissant/__init__.py | 68 + python/qroissant/__init__.pyi | 50 + python/qroissant/_client.pyi | 453 +++++ python/qroissant/_config.pyi | 416 +++++ python/qroissant/_errors.pyi | 24 + python/qroissant/_message.pyi | 43 + python/qroissant/_repr.pyi | 67 + python/qroissant/_serde.pyi | 72 + python/qroissant/_values.pyi | 234 +++ python/qroissant/py.typed | 1 + 56 files changed, 18309 insertions(+) create mode 100644 .gitignore create mode 100644 Cargo.lock create mode 100644 Cargo.toml create mode 100644 PKG-INFO create mode 100644 README.md create mode 100644 crates/qroissant-arrow/Cargo.toml create mode 100644 crates/qroissant-arrow/src/error.rs create mode 100644 crates/qroissant-arrow/src/ingestion.rs create mode 100644 crates/qroissant-arrow/src/lib.rs create mode 100644 crates/qroissant-arrow/src/metadata.rs create mode 100644 crates/qroissant-arrow/src/options.rs create mode 100644 crates/qroissant-arrow/src/projection.rs create mode 100644 crates/qroissant-core/Cargo.toml create mode 100644 crates/qroissant-core/src/decode.rs create mode 100644 crates/qroissant-core/src/encode.rs create mode 100644 crates/qroissant-core/src/error.rs create mode 100644 crates/qroissant-core/src/extent.rs create mode 100644 crates/qroissant-core/src/frame.rs create mode 100644 crates/qroissant-core/src/lib.rs create mode 100644 crates/qroissant-core/src/pipelined.rs create mode 100644 crates/qroissant-core/src/protocol.rs create mode 100644 crates/qroissant-core/src/value.rs create mode 100644 crates/qroissant-kernels/Cargo.toml create mode 100644 crates/qroissant-kernels/src/boolean.rs create mode 100644 crates/qroissant-kernels/src/lib.rs create mode 100644 
crates/qroissant-kernels/src/nulls.rs create mode 100644 crates/qroissant-kernels/src/temporal.rs create mode 100644 crates/qroissant-python/Cargo.toml create mode 100644 crates/qroissant-python/src/client.rs create mode 100644 crates/qroissant-python/src/errors.rs create mode 100644 crates/qroissant-python/src/lib.rs create mode 100644 crates/qroissant-python/src/raw_response.rs create mode 100644 crates/qroissant-python/src/repr/cell.rs create mode 100644 crates/qroissant-python/src/repr/format.rs create mode 100644 crates/qroissant-python/src/repr/mod.rs create mode 100644 crates/qroissant-python/src/repr/options.rs create mode 100644 crates/qroissant-python/src/repr/render.rs create mode 100644 crates/qroissant-python/src/serde.rs create mode 100644 crates/qroissant-python/src/types.rs create mode 100644 crates/qroissant-python/src/values.rs create mode 100644 crates/qroissant-transport/Cargo.toml create mode 100644 crates/qroissant-transport/src/asynchronous.rs create mode 100644 crates/qroissant-transport/src/error.rs create mode 100644 crates/qroissant-transport/src/lib.rs create mode 100644 crates/qroissant-transport/src/synchronous.rs create mode 100644 pyproject.toml create mode 100644 python/qroissant/__init__.py create mode 100644 python/qroissant/__init__.pyi create mode 100644 python/qroissant/_client.pyi create mode 100644 python/qroissant/_config.pyi create mode 100644 python/qroissant/_errors.pyi create mode 100644 python/qroissant/_message.pyi create mode 100644 python/qroissant/_repr.pyi create mode 100644 python/qroissant/_serde.pyi create mode 100644 python/qroissant/_values.pyi create mode 100644 python/qroissant/py.typed diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..6bcf35d --- /dev/null +++ b/.gitignore @@ -0,0 +1,31 @@ +# Rust +/target/ +**/target/ +**/*.rs.bk +Cargo.lock.bak + +# Python +__pycache__/ +*.py[cod] +*$py.class +*.egg-info/ +.eggs/ +build/ +dist/ +.venv/ +venv/ + +# Maturin +*.so +*.pyd +*.dylib + +# 
Editors / OS +.idea/ +.vscode/ +*.swp +.DS_Store +Thumbs.db + +# Local consumer code dropped in for repro, not part of qroissant +/document.py diff --git a/Cargo.lock b/Cargo.lock new file mode 100644 index 0000000..65347bf --- /dev/null +++ b/Cargo.lock @@ -0,0 +1,1403 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 4 + +[[package]] +name = "ahash" +version = "0.8.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a15f179cd60c4584b8a8c596927aadc462e27f2ca70c04e0071964a73ba7a75" +dependencies = [ + "cfg-if", + "const-random", + "getrandom 0.3.4", + "once_cell", + "version_check", + "zerocopy", +] + +[[package]] +name = "android_system_properties" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311" +dependencies = [ + "libc", +] + +[[package]] +name = "arrow-array" +version = "58.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e53796e07a6525edaf7dc28b540d477a934aff14af97967ad1d5550878969b9e" +dependencies = [ + "ahash", + "arrow-buffer", + "arrow-data", + "arrow-schema", + "chrono", + "chrono-tz", + "half", + "hashbrown", + "num-complex", + "num-integer", + "num-traits", +] + +[[package]] +name = "arrow-buffer" +version = "58.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2c1a85bb2e94ee10b76531d8bc3ce9b7b4c0d508cabfb17d477f63f2617bd20" +dependencies = [ + "bytes", + "half", + "num-bigint", + "num-traits", +] + +[[package]] +name = "arrow-cast" +version = "58.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "89fb245db6b0e234ed8e15b644edb8664673fefe630575e94e62cd9d489a8a26" +dependencies = [ + "arrow-array", + "arrow-buffer", + "arrow-data", + "arrow-ord", + "arrow-schema", + "arrow-select", + "atoi", + "base64", + "chrono", + "comfy-table", + "half", + 
"lexical-core", + "num-traits", + "ryu", +] + +[[package]] +name = "arrow-data" +version = "58.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "189d210bc4244c715fa3ed9e6e22864673cccb73d5da28c2723fb2e527329b33" +dependencies = [ + "arrow-buffer", + "arrow-schema", + "half", + "num-integer", + "num-traits", +] + +[[package]] +name = "arrow-ord" +version = "58.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "211136cb253577ee1a6665f741a13136d4e563f64f5093ffd6fb837af90b9495" +dependencies = [ + "arrow-array", + "arrow-buffer", + "arrow-data", + "arrow-schema", + "arrow-select", +] + +[[package]] +name = "arrow-schema" +version = "58.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b47e0ca91cc438d2c7879fe95e0bca5329fff28649e30a88c6f760b1faeddcb" +dependencies = [ + "bitflags", +] + +[[package]] +name = "arrow-select" +version = "58.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "750a7d1dda177735f5e82a314485b6915c7cccdbb278262ac44090f4aba4a325" +dependencies = [ + "ahash", + "arrow-array", + "arrow-buffer", + "arrow-data", + "arrow-schema", + "num-traits", +] + +[[package]] +name = "atoi" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f28d99ec8bfea296261ca1af174f24225171fea9664ba9003cbebee704810528" +dependencies = [ + "num-traits", +] + +[[package]] +name = "autocfg" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" + +[[package]] +name = "base64" +version = "0.22.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" + +[[package]] +name = "bb8" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"457d7ed3f888dfd2c7af56d4975cade43c622f74bdcddfed6d4352f57acc6310" +dependencies = [ + "futures-util", + "parking_lot", + "portable-atomic", + "tokio", +] + +[[package]] +name = "bitflags" +version = "2.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "843867be96c8daad0d758b57df9392b6d8d271134fce549de6ce169ff98a92af" + +[[package]] +name = "bumpalo" +version = "3.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d20789868f4b01b2f2caec9f5c4e0213b41e3e5702a50157d699ae31ced2fcb" + +[[package]] +name = "bytecount" +version = "0.6.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "175812e0be2bccb6abe50bb8d566126198344f707e304f45c648fd8f2cc0365e" + +[[package]] +name = "bytemuck" +version = "1.25.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8efb64bd706a16a1bdde310ae86b351e4d21550d98d056f22f8a7f7a2183fec" +dependencies = [ + "bytemuck_derive", +] + +[[package]] +name = "bytemuck_derive" +version = "1.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f9abbd1bc6865053c427f7198e6af43bfdedc55ab791faed4fbd361d789575ff" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "bytes" +version = "1.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33" + +[[package]] +name = "cc" +version = "1.2.57" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a0dd1ca384932ff3641c8718a02769f1698e7563dc6974ffd03346116310423" +dependencies = [ + "find-msvc-tools", + "shlex", +] + +[[package]] +name = "cfg-if" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" + +[[package]] +name = "chrono" +version = "0.4.44" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "c673075a2e0e5f4a1dde27ce9dee1ea4558c7ffe648f576438a20ca1d2acc4b0" +dependencies = [ + "iana-time-zone", + "js-sys", + "num-traits", + "wasm-bindgen", + "windows-link", +] + +[[package]] +name = "chrono-tz" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a6139a8597ed92cf816dfb33f5dd6cf0bb93a6adc938f11039f371bc5bcd26c3" +dependencies = [ + "chrono", + "phf", +] + +[[package]] +name = "comfy-table" +version = "7.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "958c5d6ecf1f214b4c2bbbbf6ab9523a864bd136dcf71a7e8904799acfe1ad47" +dependencies = [ + "unicode-segmentation", + "unicode-width", +] + +[[package]] +name = "const-random" +version = "0.1.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87e00182fe74b066627d63b85fd550ac2998d4b0bd86bfed477a0ae4c7c71359" +dependencies = [ + "const-random-macro", +] + +[[package]] +name = "const-random-macro" +version = "0.1.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f9d839f2a20b0aee515dc581a6172f2321f96cab76c1a38a4c584a194955390e" +dependencies = [ + "getrandom 0.2.17", + "once_cell", + "tiny-keccak", +] + +[[package]] +name = "core-foundation-sys" +version = "0.8.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" + +[[package]] +name = "crossbeam-deque" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51" +dependencies = [ + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" +dependencies = [ + 
"crossbeam-utils", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" + +[[package]] +name = "crunchy" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "460fbee9c2c2f33933d720630a6a0bac33ba7053db5344fac858d4b8952d77d5" + +[[package]] +name = "either" +version = "1.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" + +[[package]] +name = "equivalent" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" + +[[package]] +name = "find-msvc-tools" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5baebc0774151f905a1a2cc41989300b1e6fbb29aff0ceffa1064fdd3088d582" + +[[package]] +name = "fnv" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" + +[[package]] +name = "futures" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b147ee9d1f6d097cef9ce628cd2ee62288d963e16fb287bd9286455b241382d" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-channel" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07bbe89c50d7a535e539b8c17bc0b49bdb77747034daa8087407d655f3f7cc1d" +dependencies = [ + "futures-core", + "futures-sink", +] + +[[package]] +name = "futures-core" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"7e3450815272ef58cec6d564423f6e755e25379b217b0bc688e295ba24df6b1d" + +[[package]] +name = "futures-executor" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf29c38818342a3b26b5b923639e7b1f4a61fc5e76102d4b1981c6dc7a7579d" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-io" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cecba35d7ad927e23624b22ad55235f2239cfa44fd10428eecbeba6d6a717718" + +[[package]] +name = "futures-macro" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e835b70203e41293343137df5c0664546da5745f82ec9b84d40be8336958447b" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "futures-sink" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c39754e157331b013978ec91992bde1ac089843443c49cbc7f46150b0fad0893" + +[[package]] +name = "futures-task" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "037711b3d59c33004d3856fbdc83b99d4ff37a24768fa1be9ce3538a1cde4393" + +[[package]] +name = "futures-util" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "389ca41296e6190b48053de0321d02a77f32f8a5d2461dd38762c0593805c6d6" +dependencies = [ + "futures-channel", + "futures-core", + "futures-io", + "futures-macro", + "futures-sink", + "futures-task", + "memchr", + "pin-project-lite", + "slab", +] + +[[package]] +name = "getrandom" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff2abc00be7fca6ebc474524697ae276ad847ad0a6b3faa4bcb027e9a4614ad0" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + +[[package]] +name = "getrandom" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd" +dependencies = [ + "cfg-if", + "libc", + "r-efi", + "wasip2", +] + +[[package]] +name = "half" +version = "2.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ea2d84b969582b4b1864a92dc5d27cd2b77b622a8d79306834f1be5ba20d84b" +dependencies = [ + "cfg-if", + "crunchy", + "num-traits", + "zerocopy", +] + +[[package]] +name = "hashbrown" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" + +[[package]] +name = "heck" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" + +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + +[[package]] +name = "iana-time-zone" +version = "0.1.65" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e31bc9ad994ba00e440a8aa5c9ef0ec67d5cb5e5cb0cc7f8b744a35b389cc470" +dependencies = [ + "android_system_properties", + "core-foundation-sys", + "iana-time-zone-haiku", + "js-sys", + "log", + "wasm-bindgen", + "windows-core", +] + +[[package]] +name = "iana-time-zone-haiku" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f" +dependencies = [ + "cc", +] + +[[package]] +name = "indexmap" +version = "2.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7714e70437a7dc3ac8eb7e6f8df75fd8eb422675fc7678aff7364301092b1017" +dependencies = [ + "equivalent", + "hashbrown", +] + +[[package]] +name = "js-sys" +version = "0.3.91" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"b49715b7073f385ba4bc528e5747d02e66cb39c6146efb66b781f131f0fb399c" +dependencies = [ + "once_cell", + "wasm-bindgen", +] + +[[package]] +name = "lexical-core" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7d8d125a277f807e55a77304455eb7b1cb52f2b18c143b60e766c120bd64a594" +dependencies = [ + "lexical-parse-float", + "lexical-parse-integer", + "lexical-util", + "lexical-write-float", + "lexical-write-integer", +] + +[[package]] +name = "lexical-parse-float" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52a9f232fbd6f550bc0137dcb5f99ab674071ac2d690ac69704593cb4abbea56" +dependencies = [ + "lexical-parse-integer", + "lexical-util", +] + +[[package]] +name = "lexical-parse-integer" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a7a039f8fb9c19c996cd7b2fcce303c1b2874fe1aca544edc85c4a5f8489b34" +dependencies = [ + "lexical-util", +] + +[[package]] +name = "lexical-util" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2604dd126bb14f13fb5d1bd6a66155079cb9fa655b37f875b3a742c705dbed17" + +[[package]] +name = "lexical-write-float" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50c438c87c013188d415fbabbb1dceb44249ab81664efbd31b14ae55dabb6361" +dependencies = [ + "lexical-util", + "lexical-write-integer", +] + +[[package]] +name = "lexical-write-integer" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "409851a618475d2d5796377cad353802345cba92c867d9fbcde9cf4eac4e14df" +dependencies = [ + "lexical-util", +] + +[[package]] +name = "libc" +version = "0.2.183" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b5b646652bf6661599e1da8901b3b9522896f01e736bad5f723fe7a3a27f899d" + +[[package]] +name = "libm" +version = "0.2.16" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6d2cec3eae94f9f509c767b45932f1ada8350c4bdb85af2fcab4a3c14807981" + +[[package]] +name = "lock_api" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "224399e74b87b5f3557511d98dff8b14089b3dadafcab6bb93eab67d3aace965" +dependencies = [ + "scopeguard", +] + +[[package]] +name = "log" +version = "0.4.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" + +[[package]] +name = "matrixmultiply" +version = "0.3.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a06de3016e9fae57a36fd14dba131fccf49f74b40b7fbdb472f96e361ec71a08" +dependencies = [ + "autocfg", + "rawpointer", +] + +[[package]] +name = "memchr" +version = "2.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79" + +[[package]] +name = "mio" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a69bcab0ad47271a0234d9422b131806bf3968021e5dc9328caf2d4cd58557fc" +dependencies = [ + "libc", + "wasi", + "windows-sys", +] + +[[package]] +name = "ndarray" +version = "0.17.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "520080814a7a6b4a6e9070823bb24b4531daac8c4627e08ba5de8c5ef2f2752d" +dependencies = [ + "matrixmultiply", + "num-complex", + "num-integer", + "num-traits", + "portable-atomic", + "portable-atomic-util", + "rawpointer", +] + +[[package]] +name = "num-bigint" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" +dependencies = [ + "num-integer", + "num-traits", +] + +[[package]] +name = "num-complex" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" 
+checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", + "libm", +] + +[[package]] +name = "numpy" +version = "0.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "778da78c64ddc928ebf5ad9df5edf0789410ff3bdbf3619aed51cd789a6af1e2" +dependencies = [ + "half", + "libc", + "ndarray", + "num-complex", + "num-integer", + "num-traits", + "pyo3", + "pyo3-build-config", + "rustc-hash", +] + +[[package]] +name = "once_cell" +version = "1.21.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f7c3e4beb33f85d45ae3e3a1792185706c8e16d043238c593331cc7cd313b50" + +[[package]] +name = "papergrid" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2b0f8def1f117e13c895f3eda65a7b5650688da29d6ad04635f61bc7b92eebd" +dependencies = [ + "bytecount", + "fnv", + "unicode-width", +] + +[[package]] +name = "parking_lot" +version = "0.12.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93857453250e3077bd71ff98b6a65ea6621a19bb0f559a85248955ac12c45a1a" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2621685985a2ebf1c516881c026032ac7deafcda1a2c9b7850dc81e3dfcb64c1" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall", + "smallvec", + "windows-link", +] + +[[package]] +name = 
"phf" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "913273894cec178f401a31ec4b656318d95473527be05c0752cc41cdc32be8b7" +dependencies = [ + "phf_shared", +] + +[[package]] +name = "phf_shared" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06005508882fb681fd97892ecff4b7fd0fee13ef1aa569f8695dae7ab9099981" +dependencies = [ + "siphasher", +] + +[[package]] +name = "pin-project-lite" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a89322df9ebe1c1578d689c92318e070967d1042b512afbe49518723f4e6d5cd" + +[[package]] +name = "portable-atomic" +version = "1.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c33a9471896f1c69cecef8d20cbe2f7accd12527ce60845ff44c153bb2a21b49" + +[[package]] +name = "portable-atomic-util" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "091397be61a01d4be58e7841595bd4bfedb15f1cd54977d79b8271e94ed799a3" +dependencies = [ + "portable-atomic", +] + +[[package]] +name = "proc-macro-error-attr2" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96de42df36bb9bba5542fe9f1a054b8cc87e172759a1868aa05c1f3acc89dfc5" +dependencies = [ + "proc-macro2", + "quote", +] + +[[package]] +name = "proc-macro-error2" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "11ec05c52be0a07b08061f7dd003e7d7092e0472bc731b4af7bb1ef876109802" +dependencies = [ + "proc-macro-error-attr2", + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "proc-macro2" +version = "1.0.106" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "pyo3" +version = "0.28.2" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "cf85e27e86080aafd5a22eae58a162e133a589551542b3e5cee4beb27e54f8e1" +dependencies = [ + "chrono", + "chrono-tz", + "indexmap", + "libc", + "once_cell", + "portable-atomic", + "pyo3-build-config", + "pyo3-ffi", + "pyo3-macros", +] + +[[package]] +name = "pyo3-arrow" +version = "0.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0360400036dda3db3d69102ef7e9646e4cd946c75a2d1d41fb8fd39879312636" +dependencies = [ + "arrow-array", + "arrow-buffer", + "arrow-cast", + "arrow-data", + "arrow-schema", + "arrow-select", + "chrono", + "chrono-tz", + "half", + "indexmap", + "numpy", + "pyo3", + "thiserror 1.0.69", +] + +[[package]] +name = "pyo3-async-runtimes" +version = "0.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e7364a95bf00e8377bbf9b0f09d7ff9715a29d8fcf93b47d1a967363b973178" +dependencies = [ + "futures-channel", + "futures-util", + "once_cell", + "pin-project-lite", + "pyo3", + "tokio", +] + +[[package]] +name = "pyo3-build-config" +version = "0.28.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8bf94ee265674bf76c09fa430b0e99c26e319c945d96ca0d5a8215f31bf81cf7" +dependencies = [ + "target-lexicon", +] + +[[package]] +name = "pyo3-ffi" +version = "0.28.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "491aa5fc66d8059dd44a75f4580a2962c1862a1c2945359db36f6c2818b748dc" +dependencies = [ + "libc", + "pyo3-build-config", +] + +[[package]] +name = "pyo3-macros" +version = "0.28.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f5d671734e9d7a43449f8480f8b38115df67bef8d21f76837fa75ee7aaa5e52e" +dependencies = [ + "proc-macro2", + "pyo3-macros-backend", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "pyo3-macros-backend" +version = "0.28.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"22faaa1ce6c430a1f71658760497291065e6450d7b5dc2bcf254d49f66ee700a" +dependencies = [ + "heck 0.5.0", + "proc-macro2", + "pyo3-build-config", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "qroissant-arrow" +version = "0.3.0" +dependencies = [ + "arrow-array", + "arrow-buffer", + "arrow-schema", + "arrow-select", + "bytemuck", + "bytes", + "chrono", + "qroissant-core", + "qroissant-kernels", + "rayon", + "thiserror 2.0.18", +] + +[[package]] +name = "qroissant-core" +version = "0.3.0" +dependencies = [ + "bytemuck", + "bytes", + "futures", + "memchr", + "rayon", + "tokio", +] + +[[package]] +name = "qroissant-kernels" +version = "0.3.0" + +[[package]] +name = "qroissant-python" +version = "0.3.0" +dependencies = [ + "bb8", + "bytes", + "chrono", + "pyo3", + "pyo3-arrow", + "pyo3-async-runtimes", + "qroissant-arrow", + "qroissant-core", + "qroissant-kernels", + "qroissant-transport", + "r2d2", + "tabled", + "thiserror 2.0.18", + "tokio", +] + +[[package]] +name = "qroissant-transport" +version = "0.3.0" +dependencies = [ + "bytes", + "futures", + "qroissant-core", + "tokio", +] + +[[package]] +name = "quote" +version = "1.0.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41f2619966050689382d2b44f664f4bc593e129785a36d6ee376ddf37259b924" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "r-efi" +version = "5.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" + +[[package]] +name = "r2d2" +version = "0.8.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51de85fb3fb6524929c8a2eb85e6b6d363de4e8c48f9e2c2eac4944abc181c93" +dependencies = [ + "log", + "parking_lot", + "scheduled-thread-pool", +] + +[[package]] +name = "rawpointer" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + 
+[[package]] +name = "rayon" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "368f01d005bf8fd9b1206fb6fa653e6c4a81ceb1466406b81792d87c5677a58f" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] +name = "rayon-core" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22e18b0f0062d30d4230b2e85ff77fdfe4326feb054b9783a3460d8435c8ab91" +dependencies = [ + "crossbeam-deque", + "crossbeam-utils", +] + +[[package]] +name = "redox_syscall" +version = "0.5.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed2bf2547551a7053d6fdfafda3f938979645c44812fbfcda098faae3f1a362d" +dependencies = [ + "bitflags", +] + +[[package]] +name = "rustc-hash" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" + +[[package]] +name = "rustversion" +version = "1.0.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" + +[[package]] +name = "ryu" +version = "1.0.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9774ba4a74de5f7b1c1451ed6cd5285a32eddb5cccb8cc655a4e50009e06477f" + +[[package]] +name = "scheduled-thread-pool" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3cbc66816425a074528352f5789333ecff06ca41b36b0b0efdfbb29edc391a19" +dependencies = [ + "parking_lot", +] + +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + +[[package]] +name = "shlex" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" + 
+[[package]] +name = "siphasher" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2aa850e253778c88a04c3d7323b043aeda9d3e30d5971937c1855769763678e" + +[[package]] +name = "slab" +version = "0.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c790de23124f9ab44544d7ac05d60440adc586479ce501c1d6d7da3cd8c9cf5" + +[[package]] +name = "smallvec" +version = "1.15.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" + +[[package]] +name = "socket2" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3a766e1110788c36f4fa1c2b71b387a7815aa65f88ce0229841826633d93723e" +dependencies = [ + "libc", + "windows-sys", +] + +[[package]] +name = "syn" +version = "1.0.109" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "syn" +version = "2.0.117" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e665b8803e7b1d2a727f4023456bbbbe74da67099c585258af0ad9c5013b9b99" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "tabled" +version = "0.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c6709222f3973137427ce50559cd564dc187a95b9cfe01613d2f4e93610e510a" +dependencies = [ + "papergrid", + "tabled_derive", +] + +[[package]] +name = "tabled_derive" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "931be476627d4c54070a1f3a9739ccbfec9b36b39815106a20cce2243bbcefe1" +dependencies = [ + "heck 0.4.1", + "proc-macro-error2", + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "target-lexicon" +version = "0.13.5" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "adb6935a6f5c20170eeceb1a3835a49e12e19d792f6dd344ccc76a985ca5a6ca" + +[[package]] +name = "thiserror" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" +dependencies = [ + "thiserror-impl 1.0.69", +] + +[[package]] +name = "thiserror" +version = "2.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4288b5bcbc7920c07a1149a35cf9590a2aa808e0bc1eafaade0b80947865fbc4" +dependencies = [ + "thiserror-impl 2.0.18", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "thiserror-impl" +version = "2.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc4ee7f67670e9b64d05fa4253e753e016c6c95ff35b89b7941d6b856dec1d5" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "tiny-keccak" +version = "2.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2c9d3793400a45f954c52e73d068316d76b6f4e36977e3fcebb13a2721e80237" +dependencies = [ + "crunchy", +] + +[[package]] +name = "tokio" +version = "1.50.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "27ad5e34374e03cfffefc301becb44e9dc3c17584f414349ebe29ed26661822d" +dependencies = [ + "bytes", + "libc", + "mio", + "parking_lot", + "pin-project-lite", + "socket2", + "tokio-macros", + "windows-sys", +] + +[[package]] +name = "tokio-macros" +version = "2.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c55a2eff8b69ce66c84f85e1da1c233edc36ceb85a2058d11b0d6a3c7e7569c" +dependencies = [ + "proc-macro2", + "quote", + "syn 
2.0.117", +] + +[[package]] +name = "unicode-ident" +version = "1.0.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" + +[[package]] +name = "unicode-segmentation" +version = "1.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493" + +[[package]] +name = "unicode-width" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4ac048d71ede7ee76d585517add45da530660ef4390e49b098733c6e897f254" + +[[package]] +name = "version_check" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" + +[[package]] +name = "wasi" +version = "0.11.1+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" + +[[package]] +name = "wasip2" +version = "1.0.2+wasi-0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9517f9239f02c069db75e65f174b3da828fe5f5b945c4dd26bd25d89c03ebcf5" +dependencies = [ + "wit-bindgen", +] + +[[package]] +name = "wasm-bindgen" +version = "0.2.114" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6532f9a5c1ece3798cb1c2cfdba640b9b3ba884f5db45973a6f442510a87d38e" +dependencies = [ + "cfg-if", + "once_cell", + "rustversion", + "wasm-bindgen-macro", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-macro" +version = "0.2.114" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "18a2d50fcf105fb33bb15f00e7a77b772945a2ee45dcf454961fd843e74c18e6" +dependencies = [ + "quote", + "wasm-bindgen-macro-support", +] + +[[package]] +name = "wasm-bindgen-macro-support" +version = "0.2.114" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "03ce4caeaac547cdf713d280eda22a730824dd11e6b8c3ca9e42247b25c631e3" +dependencies = [ + "bumpalo", + "proc-macro2", + "quote", + "syn 2.0.117", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-shared" +version = "0.2.114" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75a326b8c223ee17883a4251907455a2431acc2791c98c26279376490c378c16" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "windows-core" +version = "0.62.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8e83a14d34d0623b51dce9581199302a221863196a1dde71a7663a4c2be9deb" +dependencies = [ + "windows-implement", + "windows-interface", + "windows-link", + "windows-result", + "windows-strings", +] + +[[package]] +name = "windows-implement" +version = "0.60.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "053e2e040ab57b9dc951b72c264860db7eb3b0200ba345b4e4c3b14f67855ddf" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "windows-interface" +version = "0.59.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f316c4a2570ba26bbec722032c4099d8c8bc095efccdc15688708623367e358" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + +[[package]] +name = "windows-link" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" + +[[package]] +name = "windows-result" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7781fa89eaf60850ac3d2da7af8e5242a5ea78d1a11c49bf2910bb5a73853eb5" +dependencies = [ + "windows-link", +] + +[[package]] +name = "windows-strings" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"7837d08f69c77cf6b07689544538e017c1bfcf57e34b4c0ff58e6c2cd3b37091" +dependencies = [ + "windows-link", +] + +[[package]] +name = "windows-sys" +version = "0.61.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae137229bcbd6cdf0f7b80a31df61766145077ddf49416a728b02cb3921ff3fc" +dependencies = [ + "windows-link", +] + +[[package]] +name = "wit-bindgen" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7249219f66ced02969388cf2bb044a09756a083d0fab1e566056b04d9fbcaa5" + +[[package]] +name = "zerocopy" +version = "0.8.47" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "efbb2a062be311f2ba113ce66f697a4dc589f85e78a4aea276200804cea0ed87" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.8.47" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e8bc7269b54418e7aeeef514aa68f8690b8c0489a06b0136e5f57c4c5ccab89" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..cf28ecc --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,26 @@ +[workspace] +members = ["crates/qroissant-core", "crates/qroissant-kernels", "crates/qroissant-arrow", "crates/qroissant-transport", "crates/qroissant-python"] +resolver = "3" + +[workspace.package] +version = "0.3.0" +edition = "2024" +license = "Apache-2.0" +repository = "https://github.com/qroissant/qroissant" + +[profile.release] +lto = "fat" +codegen-units = 1 +opt-level = 3 + +[workspace.dependencies] +pyo3 = "0.28.2" +tokio = { version = "1.48.0", features = [ + "io-util", + "net", + "rt-multi-thread", + "sync", + "time", + "macros", +] } +futures = "0.3" diff --git a/PKG-INFO b/PKG-INFO new file mode 100644 index 0000000..4361bac --- /dev/null +++ b/PKG-INFO @@ -0,0 +1,246 @@ +Metadata-Version: 2.4 +Name: qroissant +Version: 0.3.0 +Classifier: Development Status :: 3 - Alpha +Classifier: 
License :: OSI Approved :: Apache Software License +Classifier: Programming Language :: Python :: 3 +Classifier: Programming Language :: Python :: 3 :: Only +Classifier: Programming Language :: Python :: 3.10 +Classifier: Programming Language :: Python :: 3.11 +Classifier: Programming Language :: Python :: 3.12 +Classifier: Programming Language :: Python :: 3.13 +Classifier: Programming Language :: Rust +Summary: q/kdb+ IPC client library with Arrow-native Python interoperability +Keywords: kdb,q,ipc,arrow,pyo3 +Author: qroissant contributors +License-Expression: Apache-2.0 +Requires-Python: >=3.10 +Description-Content-Type: text/markdown; charset=UTF-8; variant=GFM + +# qroissant + +qroissant is a minimal q/kdb+ IPC client library with first-class support for the Apache Arrow ecosystem. + +- **Lightweight** — qroissant is a minimal library weighing in at less than 4 MiB with no required dependencies. +- **Fast** — qroissant is written in Rust, a safe and high-performance systems programming language. Moreover, qroissant uses your system resources to the best extent possible by leveraging zero-copy, multithreading, and other vectorization techniques such as SIMD. +- **Modular** — qroissant relies heavily on the [Apache Arrow PyCapsule Interface](https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html) for communicating with other libraries from the Apache Arrow ecosystem with zero-copy. This includes pyarrow, polars, duckdb, pandas, datafusion, and more. +- **Type hints** — qroissant provides type annotations for all of its functionality. + +--- + +## Installation + +```bash +pip install qroissant +``` + +Requires Python 3.10+. Wheels are available for Linux (x86\_64, aarch64), macOS (universal2), and Windows (x86\_64). 
+ +--- + +## Quick start + +### Connect and query + +```python +import qroissant as q + +endpoint = q.Endpoint.tcp("localhost", 5000) + +with q.Connection(endpoint) as conn: + result = conn.query("select from trade where date = .z.d") + print(result) # Table +``` + +### To Arrow / Polars / PyArrow + +Decoded values implement the Arrow PyCapsule protocol — pass them straight to any Arrow-aware library: + +```python +import polars as pl +import pyarrow as pa + +with q.Connection(endpoint) as conn: + table = conn.query("select from trade") + +# zero-copy — no intermediate Python objects +df = pl.from_arrow(table) +pa_table = pa.table(table) +``` + +### Async + +```python +import asyncio +import qroissant as q + +async def main(): + endpoint = q.Endpoint.tcp("localhost", 5000) + async with q.AsyncConnection(endpoint) as conn: + result = await conn.query("1 + 1") + print(result) # Atom → 2 + +asyncio.run(main()) +``` + +### Connection pool + +```python +pool_opts = q.PoolOptions( + max_size=10, + min_idle=2, + checkout_timeout_ms=5_000, + test_on_checkout=True, +) + +with q.Pool(endpoint, pool=pool_opts) as pool: + pool.prewarm() # open idle connections eagerly + result = pool.query("count trade") # checked out and returned automatically + print(pool.metrics()) # PoolMetrics(connections=2, idle=2, …) +``` + +### Streaming raw response + +For large results you can stream the raw IPC bytes before decoding: + +```python +with q.Connection(endpoint) as conn: + with conn.query("select from trade", raw=True) as resp: + print(resp.header) # MessageHeader(size=…, compression=…) + value = resp.decode() # decode on demand +``` + +### Standalone encode / decode + +```python +# decode an IPC payload you already have +payload: bytes = ... 
+value = q.decode(payload) + +# encode a value back to IPC bytes +frame = q.encode(value, message_type=q.MessageType.SYNCHRONOUS) +``` + +--- + +## Value types + +Every `conn.query()` call returns a `Value` subclass: + +| q type | Python type | Arrow export | +|--------|------------|--------------| +| scalar (atom) | `Atom` | `__arrow_c_array__` | +| typed list | `Vector` | `__arrow_c_array__` | +| mixed list | `List` | `__arrow_c_array__` | +| dictionary | `Dictionary` | `__arrow_c_array__` (StructArray) | +| table | `Table` | `__arrow_c_stream__` | + +--- + +## Decode options + +Control how IPC data is projected into Arrow: + +```python +opts = ( + q.DecodeOptions.builder() + .with_symbol_interpretation(q.SymbolInterpretation.DICTIONARY) # dict-encode symbols + .with_temporal_nulls(True) # map q null sentinels → None + .with_treat_infinity_as_null(True) # map ±∞ → None + .with_parallel(True) # decode table columns in parallel + .build() +) + +with q.Connection(endpoint, options=opts) as conn: + result = conn.query("select from trade") +``` + +--- + +## Endpoints + +```python +# TCP +endpoint = q.Endpoint.tcp( + "localhost", 5000, + username="user", + password="pass", + timeout_ms=3_000, +) + +# Unix domain socket +endpoint = q.Endpoint.unix( + "/tmp/qroissant.sock", + username="user", + password="pass", +) +``` + +--- + +## Error handling + +```python +from qroissant import ( + QroissantError, # base class + DecodeError, # malformed IPC payload + ProtocolError, # bad frame header + TransportError, # socket / IO failure + QRuntimeError, # q process returned an error + PoolError, # pool management failure + PoolClosedError, # operation on a closed pool +) + +try: + result = conn.query("invalid expression") +except q.QRuntimeError as e: + print(f"q error: {e}") +except q.TransportError as e: + print(f"connection lost: {e}") +``` + +--- + +## Architecture + +qroissant is organized as a Rust workspace with strict crate boundaries: + +``` +crates/ +├── qroissant-core # 
q protocol, value types, encode/decode +├── qroissant-transport # sync & async TCP/Unix socket connections +├── qroissant-arrow # zero-copy Arrow projection +├── qroissant-kernels # SIMD / nightly-sensitive hot paths +└── qroissant-python # PyO3 bindings (the _native extension module) +``` + +The Python package at `python/qroissant/` re-exports everything from the compiled `_native` extension. The `.pyi` stub files in that directory define the public API contract. + +--- + +## Development + +```bash +# Install Python dependencies +uv sync --group dev --group docs + +# Build the Rust extension (required before running Python tests) +uv run maturin develop + +# Run tests +uv run pytest +cargo test --workspace + +# Lint and format +uv run ruff check python/ tests/ +cargo fmt --all +``` + +Transport integration tests require a q binary. Set `Q_BIN` to the path of your q executable before running `pytest`. + +--- + +## License + +Apache 2.0 — see [LICENSE](LICENSE). + diff --git a/README.md b/README.md new file mode 100644 index 0000000..635204f --- /dev/null +++ b/README.md @@ -0,0 +1,226 @@ +# qroissant + +qroissant is a minimal q/kdb+ IPC client library with first-class support for the Apache Arrow ecosystem. + +- **Lightweight** — qroissant is a minimal library weighing in at less than 4 MiB with no required dependencies. +- **Fast** — qroissant is written in Rust, a safe and high-performance systems programming language. Moreover, qroissant uses your system resources to the best extent possible by leveraging zero-copy, multithreading, and other vectorization techniques such as SIMD. +- **Modular** — qroissant relies heavily on the [Apache Arrow PyCapsule Interface](https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html) for communicating with other libraries from the Apache Arrow ecosystem with zero-copy. This includes pyarrow, polars, duckdb, pandas, datafusion, and more. 
+- **Type hints** — qroissant provides type annotations for all of its functionality. + +--- + +## Installation + +```bash +pip install qroissant +``` + +Requires Python 3.10+. Wheels are available for Linux (x86\_64, aarch64), macOS (universal2), and Windows (x86\_64). + +--- + +## Quick start + +### Connect and query + +```python +import qroissant as q + +endpoint = q.Endpoint.tcp("localhost", 5000) + +with q.Connection(endpoint) as conn: + result = conn.query("select from trade where date = .z.d") + print(result) # Table +``` + +### To Arrow / Polars / PyArrow + +Decoded values implement the Arrow PyCapsule protocol — pass them straight to any Arrow-aware library: + +```python +import polars as pl +import pyarrow as pa + +with q.Connection(endpoint) as conn: + table = conn.query("select from trade") + +# zero-copy — no intermediate Python objects +df = pl.from_arrow(table) +pa_table = pa.table(table) +``` + +### Async + +```python +import asyncio +import qroissant as q + +async def main(): + endpoint = q.Endpoint.tcp("localhost", 5000) + async with q.AsyncConnection(endpoint) as conn: + result = await conn.query("1 + 1") + print(result) # Atom → 2 + +asyncio.run(main()) +``` + +### Connection pool + +```python +pool_opts = q.PoolOptions( + max_size=10, + min_idle=2, + checkout_timeout_ms=5_000, + test_on_checkout=True, +) + +with q.Pool(endpoint, pool=pool_opts) as pool: + pool.prewarm() # open idle connections eagerly + result = pool.query("count trade") # checked out and returned automatically + print(pool.metrics()) # PoolMetrics(connections=2, idle=2, …) +``` + +### Streaming raw response + +For large results you can stream the raw IPC bytes before decoding: + +```python +with q.Connection(endpoint) as conn: + with conn.query("select from trade", raw=True) as resp: + print(resp.header) # MessageHeader(size=…, compression=…) + value = resp.decode() # decode on demand +``` + +### Standalone encode / decode + +```python +# 
decode an IPC payload you already have +payload: bytes = ... +value = q.decode(payload) + +# encode a value back to IPC bytes +frame = q.encode(value, message_type=q.MessageType.SYNCHRONOUS) +``` + +--- + +## Value types + +Every `conn.query()` call returns a `Value` subclass: + +| q type | Python type | Arrow export | +|--------|------------|--------------| +| scalar (atom) | `Atom` | `__arrow_c_array__` | +| typed list | `Vector` | `__arrow_c_array__` | +| mixed list | `List` | `__arrow_c_array__` | +| dictionary | `Dictionary` | `__arrow_c_array__` (StructArray) | +| table | `Table` | `__arrow_c_stream__` | + +--- + +## Decode options + +Control how IPC data is projected into Arrow: + +```python +opts = ( + q.DecodeOptions.builder() + .with_symbol_interpretation(q.SymbolInterpretation.DICTIONARY) # dict-encode symbols + .with_temporal_nulls(True) # map q null sentinels → None + .with_treat_infinity_as_null(True) # map ±∞ → None + .with_parallel(True) # decode table columns in parallel + .build() +) + +with q.Connection(endpoint, options=opts) as conn: + result = conn.query("select from trade") +``` + +--- + +## Endpoints + +```python +# TCP +endpoint = q.Endpoint.tcp( + "localhost", 5000, + username="user", + password="pass", + timeout_ms=3_000, +) + +# Unix domain socket +endpoint = q.Endpoint.unix( + "/tmp/qroissant.sock", + username="user", + password="pass", +) +``` + +--- + +## Error handling + +```python +from qroissant import ( + QroissantError, # base class + DecodeError, # malformed IPC payload + ProtocolError, # bad frame header + TransportError, # socket / IO failure + QRuntimeError, # q process returned an error + PoolError, # pool management failure + PoolClosedError, # operation on a closed pool +) + +try: + result = conn.query("invalid expression") +except q.QRuntimeError as e: + print(f"q error: {e}") +except q.TransportError as e: + print(f"connection lost: {e}") +``` + +--- + +## Architecture + +qroissant is organized as a Rust workspace with 
strict crate boundaries: + +``` +crates/ +├── qroissant-core # q protocol, value types, encode/decode +├── qroissant-transport # sync & async TCP/Unix socket connections +├── qroissant-arrow # zero-copy Arrow projection +├── qroissant-kernels # SIMD / nightly-sensitive hot paths +└── qroissant-python # PyO3 bindings (the _native extension module) +``` + +The Python package at `python/qroissant/` re-exports everything from the compiled `_native` extension. The `.pyi` stub files in that directory define the public API contract. + +--- + +## Development + +```bash +# Install Python dependencies +uv sync --group dev --group docs + +# Build the Rust extension (required before running Python tests) +uv run maturin develop + +# Run tests +uv run pytest +cargo test --workspace + +# Lint and format +uv run ruff check python/ tests/ +cargo fmt --all +``` + +Transport integration tests require a q binary. Set `Q_BIN` to the path of your q executable before running `pytest`. + +--- + +## License + +Apache 2.0 — see [LICENSE](LICENSE). 
diff --git a/crates/qroissant-arrow/Cargo.toml b/crates/qroissant-arrow/Cargo.toml new file mode 100644 index 0000000..395c0d8 --- /dev/null +++ b/crates/qroissant-arrow/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "qroissant-arrow" +version.workspace = true +edition.workspace = true +license.workspace = true +publish = false + +[lib] +name = "qroissant_arrow" +path = "src/lib.rs" + +[dependencies] +arrow-array = "58.0.0" +arrow-buffer = "58.0.0" +arrow-schema = "58.0.0" +arrow-select = "58.0.0" +bytemuck = { version = "1", features = ["derive", "extern_crate_alloc"] } +bytes = "1.11.1" +chrono = "0.4.44" +qroissant-core = { path = "../qroissant-core" } +qroissant-kernels = { path = "../qroissant-kernels" } +rayon = "1.10" +thiserror = "2.0.18" + diff --git a/crates/qroissant-arrow/src/error.rs b/crates/qroissant-arrow/src/error.rs new file mode 100644 index 0000000..0965c89 --- /dev/null +++ b/crates/qroissant-arrow/src/error.rs @@ -0,0 +1,23 @@ +use thiserror::Error; + +#[derive(Debug, Error)] +pub enum ProjectionError { + #[error("Arrow projection is not supported for {0}")] + Unsupported(String), + + #[error("Arrow projection failed: {0}")] + Arrow(String), +} + +pub type ProjectionResult = Result; + +#[derive(Debug, Error)] +pub enum IngestionError { + #[error("Arrow ingestion is not supported: {0}")] + Unsupported(String), + + #[error("Arrow ingestion failed: {0}")] + Arrow(#[from] arrow_schema::ArrowError), +} + +pub type IngestionResult = Result; diff --git a/crates/qroissant-arrow/src/ingestion.rs b/crates/qroissant-arrow/src/ingestion.rs new file mode 100644 index 0000000..7f41143 --- /dev/null +++ b/crates/qroissant-arrow/src/ingestion.rs @@ -0,0 +1,1576 @@ +//! Arrow ingestion: converts Arrow arrays and record batches into q `Value` trees. +//! +//! This is the reverse direction of [`crate::projection`]. Arrow field +//! metadata produced by the projection layer (`qroissant.shape`, +//! `qroissant.primitive`, etc.) 
is consumed here so that round-trips through +//! Arrow preserve exact q semantics. +//! +//! No PyO3 or Python dependencies are allowed in this crate; PyCapsule +//! handling lives in `qroissant-python`. + +use arrow_array::Array; +use arrow_array::ArrayRef; +use arrow_array::BinaryArray; +use arrow_array::BinaryViewArray; +use arrow_array::BooleanArray; +use arrow_array::Date32Array; +use arrow_array::DurationMicrosecondArray; +use arrow_array::DurationMillisecondArray; +use arrow_array::DurationNanosecondArray; +use arrow_array::DurationSecondArray; +use arrow_array::FixedSizeBinaryArray; +use arrow_array::Float32Array; +use arrow_array::Float64Array; +use arrow_array::Int16Array; +use arrow_array::Int32Array; +use arrow_array::Int64Array; +use arrow_array::LargeBinaryArray; +use arrow_array::LargeListArray; +use arrow_array::LargeStringArray; +use arrow_array::ListArray; +use arrow_array::MapArray; +use arrow_array::RecordBatch; +use arrow_array::StringArray; +use arrow_array::StringViewArray; +use arrow_array::StructArray; +use arrow_array::Time32MillisecondArray; +use arrow_array::Time32SecondArray; +use arrow_array::Time64MicrosecondArray; +use arrow_array::Time64NanosecondArray; +use arrow_array::TimestampMicrosecondArray; +use arrow_array::TimestampMillisecondArray; +use arrow_array::TimestampNanosecondArray; +use arrow_array::TimestampSecondArray; +use arrow_array::UInt8Array; +use arrow_schema::DataType; +use arrow_schema::Field as ArrowField; +use arrow_schema::SchemaRef; +use arrow_schema::TimeUnit; +use qroissant_core::Atom; +use qroissant_core::Attribute; +use qroissant_core::Dictionary; +use qroissant_core::List; +use qroissant_core::Table; +use qroissant_core::Value; +use qroissant_core::Vector; +use qroissant_core::VectorData; +use qroissant_kernels::nulls::Q_NULL_DATE; +use qroissant_kernels::nulls::Q_NULL_MINUTE; +use qroissant_kernels::nulls::Q_NULL_SECOND; +use qroissant_kernels::nulls::Q_NULL_SHORT; +use qroissant_kernels::nulls::Q_NULL_TIME; 
+use qroissant_kernels::nulls::Q_NULL_TIMESPAN; +use qroissant_kernels::nulls::Q_NULL_TIMESTAMP; +use qroissant_kernels::temporal::DATE_OFFSET_DAYS; +use qroissant_kernels::temporal::TIMESTAMP_OFFSET_NS; + +use crate::error::IngestionError; +use crate::error::IngestionResult; + +/// Converts a `Vec` to `bytes::Bytes` via zero-copy reinterpretation. +fn vec_to_bytes(values: Vec) -> bytes::Bytes { + // Safety: bytemuck::cast_vec requires NoUninit, which guarantees no padding. + let byte_vec: Vec = bytemuck::allocation::cast_vec(values); + bytes::Bytes::from(byte_vec) +} +use crate::metadata::ATTRIBUTE_KEY; +use crate::metadata::PRIMITIVE_KEY; +use crate::metadata::SHAPE_KEY; +use crate::metadata::SORTED_KEY; + +// --------------------------------------------------------------------------- +// Metadata hint extraction +// --------------------------------------------------------------------------- + +#[derive(Clone, Copy, Default, Debug)] +struct IngestHint { + shape: Option, + primitive: Option, + attribute: Option, + sorted: Option, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +enum IngestShape { + Atom, + Vector, + List, + Dictionary, + Table, + UnaryPrimitive, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +enum IngestPrimitive { + Boolean, + Guid, + Byte, + Short, + Int, + Long, + Real, + Float, + Char, + Symbol, + Timestamp, + Month, + Date, + Datetime, + Timespan, + Minute, + Second, + Time, +} + +fn hint_from_field(field: &ArrowField) -> IngestHint { + let meta = field.metadata(); + IngestHint { + shape: meta.get(SHAPE_KEY).and_then(|s| parse_shape(s)), + primitive: meta.get(PRIMITIVE_KEY).and_then(|s| parse_primitive(s)), + attribute: meta.get(ATTRIBUTE_KEY).and_then(|s| parse_attribute(s)), + sorted: meta.get(SORTED_KEY).and_then(|s| s.parse::().ok()), + } +} + +fn parse_shape(s: &str) -> Option { + match s { + "atom" => Some(IngestShape::Atom), + "vector" => Some(IngestShape::Vector), + "list" => Some(IngestShape::List), + "dictionary" => 
Some(IngestShape::Dictionary), + "table" => Some(IngestShape::Table), + "unary_primitive" => Some(IngestShape::UnaryPrimitive), + _ => None, + } +} + +fn parse_primitive(s: &str) -> Option { + match s { + "boolean" => Some(IngestPrimitive::Boolean), + "guid" => Some(IngestPrimitive::Guid), + "byte" => Some(IngestPrimitive::Byte), + "short" => Some(IngestPrimitive::Short), + "int" => Some(IngestPrimitive::Int), + "long" => Some(IngestPrimitive::Long), + "real" => Some(IngestPrimitive::Real), + "float" => Some(IngestPrimitive::Float), + "char" => Some(IngestPrimitive::Char), + "symbol" => Some(IngestPrimitive::Symbol), + "timestamp" => Some(IngestPrimitive::Timestamp), + "month" => Some(IngestPrimitive::Month), + "date" => Some(IngestPrimitive::Date), + "datetime" => Some(IngestPrimitive::Datetime), + "timespan" => Some(IngestPrimitive::Timespan), + "minute" => Some(IngestPrimitive::Minute), + "second" => Some(IngestPrimitive::Second), + "time" => Some(IngestPrimitive::Time), + _ => None, + } +} + +fn parse_attribute(s: &str) -> Option { + match s { + "none" => Some(Attribute::None), + "sorted" => Some(Attribute::Sorted), + "unique" => Some(Attribute::Unique), + "parted" => Some(Attribute::Parted), + "grouped" => Some(Attribute::Grouped), + _ => None, + } +} + +// --------------------------------------------------------------------------- +// Public API +// --------------------------------------------------------------------------- + +/// Convert an Arrow array + field descriptor into a q `Value`. +pub fn ingest_array(array: ArrayRef, field: &ArrowField) -> IngestionResult { + let hint = hint_from_field(field); + ingest_with_hint(array, hint) +} + +/// Convert an Arrow record batch into a q table `Value`. 
+pub fn ingest_record_batch(batch: RecordBatch) -> IngestionResult { + let schema = batch.schema(); + let mut column_names = Vec::with_capacity(batch.num_columns()); + let mut columns = Vec::with_capacity(batch.num_columns()); + for (index, field) in schema.fields().iter().enumerate() { + column_names.push(bytes::Bytes::copy_from_slice(field.name().as_bytes())); + columns.push(ingest_array(batch.column(index).clone(), field.as_ref())?); + } + let table = Table::new(Attribute::None, column_names, columns); + table + .validate() + .map_err(|e| IngestionError::Unsupported(e.to_string()))?; + Ok(Value::Table(table)) +} + +/// Convert a sequence of record batches (a stream) into a q table `Value`. +/// +/// All batches must share the same schema. The batches are concatenated using +/// `arrow_select::concat::concat_batches` before ingestion. +pub fn ingest_record_batch_reader( + schema: SchemaRef, + batches: impl IntoIterator>, +) -> IngestionResult { + let batches: Vec = batches.into_iter().collect::>()?; + if batches.is_empty() { + // Produce an empty table with the correct schema. 
+ let column_names: Vec = schema + .fields() + .iter() + .map(|f| bytes::Bytes::copy_from_slice(f.name().as_bytes())) + .collect(); + let columns: Vec = schema + .fields() + .iter() + .map(|f| ingest_array(arrow_array::new_empty_array(f.data_type()), f.as_ref())) + .collect::>()?; + let table = Table::new(Attribute::None, column_names, columns); + return Ok(Value::Table(table)); + } + let merged = arrow_select::concat::concat_batches(&schema, &batches)?; + ingest_record_batch(merged) +} + +// --------------------------------------------------------------------------- +// Main dispatch +// --------------------------------------------------------------------------- + +fn ingest_with_hint(array: ArrayRef, hint: IngestHint) -> IngestionResult { + let shape = hint + .shape + .unwrap_or_else(|| default_shape(array.data_type())); + + match shape { + IngestShape::UnaryPrimitive => { + if array.len() != 1 { + return Err(IngestionError::Unsupported(format!( + "unary_primitive shape requires length 1, got {}", + array.len() + ))); + } + Ok(Value::UnaryPrimitive { opcode: -128 }) + } + IngestShape::Table => ingest_table(array, hint), + IngestShape::Dictionary => ingest_dictionary(array, hint), + IngestShape::List => ingest_list(array, hint), + IngestShape::Atom | IngestShape::Vector => ingest_scalar_or_vector(array, shape, hint), + } +} + +fn default_shape(dt: &DataType) -> IngestShape { + match dt { + DataType::Null => IngestShape::List, + DataType::List(_) | DataType::LargeList(_) => IngestShape::List, + // Multiple binary blobs default to a list of char vectors. + // Use explicit metadata (qroissant.shape=vector) for char vector. 
+ DataType::Binary | DataType::LargeBinary | DataType::BinaryView => IngestShape::List, + DataType::Map(_, _) => IngestShape::Dictionary, + DataType::Struct(_) => IngestShape::Table, + _ => IngestShape::Vector, + } +} + +// --------------------------------------------------------------------------- +// Table ingestion (Struct array) +// --------------------------------------------------------------------------- + +fn ingest_table(array: ArrayRef, hint: IngestHint) -> IngestionResult { + let attribute = hint.attribute.unwrap_or(Attribute::None); + + let struct_array = array + .as_any() + .downcast_ref::() + .ok_or_else(|| { + IngestionError::Unsupported(format!( + "q table ingestion requires a StructArray, found {}", + array.data_type() + )) + })?; + + let fields = match array.data_type() { + DataType::Struct(fields) => fields.clone(), + other => { + return Err(IngestionError::Unsupported(format!( + "q table ingestion requires a struct field, found {other}" + ))); + } + }; + + let mut column_names = Vec::with_capacity(fields.len()); + let mut columns = Vec::with_capacity(fields.len()); + for (i, child_field) in fields.iter().enumerate() { + column_names.push(bytes::Bytes::copy_from_slice(child_field.name().as_bytes())); + columns.push(ingest_array( + struct_array.column(i).clone(), + child_field.as_ref(), + )?); + } + + let table = Table::new(attribute, column_names, columns); + table + .validate() + .map_err(|e| IngestionError::Unsupported(e.to_string()))?; + Ok(Value::Table(table)) +} + +// --------------------------------------------------------------------------- +// Dictionary ingestion (Map array) +// --------------------------------------------------------------------------- + +fn ingest_dictionary(array: ArrayRef, hint: IngestHint) -> IngestionResult { + let map_array = array.as_any().downcast_ref::().ok_or_else(|| { + IngestionError::Unsupported(format!( + "q dictionary ingestion requires a MapArray, found {}", + array.data_type() + )) + })?; + + if 
map_array.len() != 1 || map_array.is_null(0) { + return Err(IngestionError::Unsupported( + "q dictionary ingestion requires a non-null length-1 Arrow map".to_string(), + )); + } + + let entries = map_array.value(0); + let sorted = hint.sorted.unwrap_or(false); + let entry_fields = entries.fields().clone(); + let keys = ingest_array(entries.column(0).clone(), entry_fields[0].as_ref())?; + let values = ingest_array(entries.column(1).clone(), entry_fields[1].as_ref())?; + let dict = Dictionary::new(sorted, keys, values); + dict.validate() + .map_err(|e| IngestionError::Unsupported(e.to_string()))?; + Ok(Value::Dictionary(dict)) +} + +// --------------------------------------------------------------------------- +// List ingestion (List / LargeList / Binary / BinaryView arrays) +// --------------------------------------------------------------------------- + +fn ingest_list(array: ArrayRef, hint: IngestHint) -> IngestionResult { + let attribute = hint.attribute.unwrap_or(Attribute::None); + + match array.data_type() { + DataType::Null => { + let values = (0..array.len()) + .map(|_| Value::UnaryPrimitive { opcode: -128 }) + .collect(); + Ok(Value::List(List::new(attribute, values))) + } + DataType::List(child_field) => { + let child_field = child_field.clone(); + let list_array = array + .as_any() + .downcast_ref::() + .expect("List datatype must match ListArray"); + let mut values = Vec::with_capacity(list_array.len()); + for i in 0..list_array.len() { + let child = list_array.value(i); + values.push(ingest_array(child, child_field.as_ref())?); + } + Ok(Value::List(List::new(attribute, values))) + } + DataType::LargeList(child_field) => { + let child_field = child_field.clone(); + let list_array = array + .as_any() + .downcast_ref::() + .expect("LargeList datatype must match LargeListArray"); + let mut values = Vec::with_capacity(list_array.len()); + for i in 0..list_array.len() { + let child = list_array.value(i); + values.push(ingest_array(child, 
child_field.as_ref())?); + } + Ok(Value::List(List::new(attribute, values))) + } + DataType::Binary => { + let binary = array + .as_any() + .downcast_ref::() + .expect("Binary datatype must match BinaryArray"); + let values = (0..binary.len()) + .map(|i| { + Value::Vector(Vector::new( + Attribute::None, + VectorData::Char(bytes::Bytes::copy_from_slice(binary.value(i))), + )) + }) + .collect(); + Ok(Value::List(List::new(attribute, values))) + } + DataType::LargeBinary => { + let binary = array + .as_any() + .downcast_ref::() + .expect("LargeBinary datatype must match LargeBinaryArray"); + let values = (0..binary.len()) + .map(|i| { + Value::Vector(Vector::new( + Attribute::None, + VectorData::Char(bytes::Bytes::copy_from_slice(binary.value(i))), + )) + }) + .collect(); + Ok(Value::List(List::new(attribute, values))) + } + DataType::BinaryView => { + let binary = array + .as_any() + .downcast_ref::() + .expect("BinaryView datatype must match BinaryViewArray"); + let values = (0..binary.len()) + .map(|i| { + Value::Vector(Vector::new( + Attribute::None, + VectorData::Char(bytes::Bytes::copy_from_slice(binary.value(i))), + )) + }) + .collect(); + Ok(Value::List(List::new(attribute, values))) + } + other => Err(IngestionError::Unsupported(format!( + "q list ingestion from Arrow data type {other} is not supported" + ))), + } +} + +// --------------------------------------------------------------------------- +// Scalar / vector ingestion +// --------------------------------------------------------------------------- + +fn ingest_scalar_or_vector( + array: ArrayRef, + shape: IngestShape, + hint: IngestHint, +) -> IngestionResult { + let attribute = hint.attribute.unwrap_or(Attribute::None); + let is_atom = shape == IngestShape::Atom; + + if is_atom && array.len() != 1 { + return Err(IngestionError::Unsupported(format!( + "q atom shape requested but Arrow array has length {}", + array.len() + ))); + } + + match array.data_type() { + DataType::Boolean => 
ingest_boolean(&array, is_atom, attribute), + DataType::UInt8 => { + let prim = hint.primitive.unwrap_or(IngestPrimitive::Byte); + ingest_u8(&array, prim, is_atom, attribute) + } + DataType::Int16 => ingest_i16(&array, is_atom, attribute), + DataType::Int32 => { + let prim = hint.primitive.unwrap_or(IngestPrimitive::Int); + ingest_i32(&array, prim, is_atom, attribute) + } + DataType::Int64 => ingest_i64(&array, is_atom, attribute), + DataType::Float32 => ingest_f32(&array, is_atom, attribute), + DataType::Float64 => { + let prim = hint.primitive.unwrap_or(IngestPrimitive::Float); + ingest_f64(&array, prim, is_atom, attribute) + } + DataType::FixedSizeBinary(1) => { + let prim = hint.primitive.unwrap_or(IngestPrimitive::Char); + ingest_fixed_binary_1(&array, prim, is_atom, attribute) + } + DataType::FixedSizeBinary(16) => ingest_fixed_binary_16(&array, is_atom, attribute), + DataType::Utf8 => ingest_symbols_utf8(&array, is_atom, attribute), + DataType::LargeUtf8 => ingest_symbols_large_utf8(&array, is_atom, attribute), + DataType::Utf8View => ingest_symbols_utf8_view(&array, is_atom, attribute), + DataType::Dictionary(_, _) => ingest_symbols_dictionary(&array, is_atom, attribute), + DataType::Binary => ingest_binary_as_char(&array, is_atom, attribute), + DataType::LargeBinary => ingest_large_binary_as_char(&array, is_atom, attribute), + DataType::BinaryView => ingest_binary_view_as_char(&array, is_atom, attribute), + DataType::Date32 => ingest_date32(&array, is_atom, attribute), + DataType::Timestamp(TimeUnit::Nanosecond, tz) => { + if tz.is_some() { + return Err(IngestionError::Unsupported( + "Arrow timestamps with timezone cannot be ingested into q".to_string(), + )); + } + ingest_timestamp_ns(&array, is_atom, attribute) + } + DataType::Timestamp(TimeUnit::Microsecond, tz) => { + if tz.is_some() { + return Err(IngestionError::Unsupported( + "Arrow timestamps with timezone cannot be ingested into q".to_string(), + )); + } + ingest_timestamp_us(&array, is_atom, 
attribute) + } + DataType::Timestamp(TimeUnit::Millisecond, tz) => { + if tz.is_some() { + return Err(IngestionError::Unsupported( + "Arrow timestamps with timezone cannot be ingested into q".to_string(), + )); + } + ingest_timestamp_ms(&array, is_atom, attribute) + } + DataType::Timestamp(TimeUnit::Second, tz) => { + if tz.is_some() { + return Err(IngestionError::Unsupported( + "Arrow timestamps with timezone cannot be ingested into q".to_string(), + )); + } + ingest_timestamp_s(&array, is_atom, attribute) + } + DataType::Duration(TimeUnit::Nanosecond) => ingest_duration_ns(&array, is_atom, attribute), + DataType::Duration(TimeUnit::Microsecond) => ingest_duration_us(&array, is_atom, attribute), + DataType::Duration(TimeUnit::Millisecond) => ingest_duration_ms(&array, is_atom, attribute), + DataType::Duration(TimeUnit::Second) => ingest_duration_s(&array, is_atom, attribute), + DataType::Time32(TimeUnit::Second) => { + let prim = hint.primitive.unwrap_or(IngestPrimitive::Second); + ingest_time32_second(&array, prim, is_atom, attribute) + } + DataType::Time32(TimeUnit::Millisecond) => ingest_time32_ms(&array, is_atom, attribute), + DataType::Time64(TimeUnit::Microsecond) => ingest_time64_us(&array, is_atom, attribute), + DataType::Time64(TimeUnit::Nanosecond) => ingest_time64_ns(&array, is_atom, attribute), + other => Err(IngestionError::Unsupported(format!( + "q ingestion from Arrow data type {other} is not supported" + ))), + } +} + +// --------------------------------------------------------------------------- +// Boolean +// --------------------------------------------------------------------------- + +fn ingest_boolean(array: &ArrayRef, is_atom: bool, attribute: Attribute) -> IngestionResult { + let arr = array + .as_any() + .downcast_ref::() + .expect("Boolean datatype must match BooleanArray"); + + if arr.null_count() != 0 { + return Err(IngestionError::Unsupported( + "Arrow boolean arrays with nulls cannot be ingested as q boolean vectors; \ + use a general 
list shape instead" + .to_string(), + )); + } + + let values: Vec = (0..arr.len()) + .map(|i| if arr.value(i) { 1 } else { 0 }) + .collect(); + if is_atom { + Ok(Value::Atom(Atom::Boolean(values[0] != 0))) + } else { + Ok(Value::Vector(Vector::new( + attribute, + VectorData::Boolean(bytes::Bytes::from(values)), + ))) + } +} + +// --------------------------------------------------------------------------- +// UInt8 (Byte or Char) +// --------------------------------------------------------------------------- + +fn ingest_u8( + array: &ArrayRef, + prim: IngestPrimitive, + is_atom: bool, + attribute: Attribute, +) -> IngestionResult { + let arr = array + .as_any() + .downcast_ref::() + .expect("UInt8 datatype must match UInt8Array"); + + if arr.null_count() != 0 { + return Err(IngestionError::Unsupported( + "Arrow UInt8 arrays with nulls cannot be ingested as q byte/char".to_string(), + )); + } + + let values: Vec = arr.values().to_vec(); + match prim { + IngestPrimitive::Char => { + if is_atom { + Ok(Value::Atom(Atom::Char(values[0]))) + } else { + Ok(Value::Vector(Vector::new( + attribute, + VectorData::Char(bytes::Bytes::from(values)), + ))) + } + } + _ => { + if is_atom { + Ok(Value::Atom(Atom::Byte(values[0]))) + } else { + Ok(Value::Vector(Vector::new( + attribute, + VectorData::Byte(bytes::Bytes::from(values)), + ))) + } + } + } +} + +// --------------------------------------------------------------------------- +// Int16 (Short) +// --------------------------------------------------------------------------- + +fn ingest_i16(array: &ArrayRef, is_atom: bool, attribute: Attribute) -> IngestionResult { + let arr = array + .as_any() + .downcast_ref::() + .expect("Int16 datatype must match Int16Array"); + + let mut values: Vec = arr.values().to_vec(); + if arr.null_count() != 0 { + for i in 0..arr.len() { + if arr.is_null(i) { + values[i] = Q_NULL_SHORT; + } + } + } + + if is_atom { + Ok(Value::Atom(Atom::Short(values[0]))) + } else { + Ok(Value::Vector(Vector::new( 
+ attribute, + VectorData::Short(vec_to_bytes(values)), + ))) + } +} + +// --------------------------------------------------------------------------- +// Int32 (Int, Month, Date, Minute, Second, Time) +// --------------------------------------------------------------------------- + +fn ingest_i32( + array: &ArrayRef, + prim: IngestPrimitive, + is_atom: bool, + attribute: Attribute, +) -> IngestionResult { + let arr = array + .as_any() + .downcast_ref::() + .expect("Int32 datatype must match Int32Array"); + + let mut values: Vec = arr.values().to_vec(); + if arr.null_count() != 0 { + let null_sentinel = i32::MIN; // all i32 q types share i32::MIN as null + for i in 0..arr.len() { + if arr.is_null(i) { + values[i] = null_sentinel; + } + } + } + + if is_atom { + let v = values[0]; + let atom = match prim { + IngestPrimitive::Month => Atom::Month(v), + IngestPrimitive::Date => Atom::Date(v), + IngestPrimitive::Minute => Atom::Minute(v), + IngestPrimitive::Second => Atom::Second(v), + IngestPrimitive::Time => Atom::Time(v), + _ => Atom::Int(v), + }; + Ok(Value::Atom(atom)) + } else { + let bytes = vec_to_bytes(values); + let data = match prim { + IngestPrimitive::Month => VectorData::Month(bytes), + IngestPrimitive::Date => VectorData::Date(bytes), + IngestPrimitive::Minute => VectorData::Minute(bytes), + IngestPrimitive::Second => VectorData::Second(bytes), + IngestPrimitive::Time => VectorData::Time(bytes), + _ => VectorData::Int(bytes), + }; + Ok(Value::Vector(Vector::new(attribute, data))) + } +} + +// --------------------------------------------------------------------------- +// Int64 (Long) +// --------------------------------------------------------------------------- + +fn ingest_i64(array: &ArrayRef, is_atom: bool, attribute: Attribute) -> IngestionResult { + let arr = array + .as_any() + .downcast_ref::() + .expect("Int64 datatype must match Int64Array"); + + let mut values: Vec = arr.values().to_vec(); + if arr.null_count() != 0 { + for i in 0..arr.len() { 
+ if arr.is_null(i) { + values[i] = i64::MIN; + } + } + } + + if is_atom { + Ok(Value::Atom(Atom::Long(values[0]))) + } else { + Ok(Value::Vector(Vector::new( + attribute, + VectorData::Long(vec_to_bytes(values)), + ))) + } +} + +// --------------------------------------------------------------------------- +// Float32 (Real) +// --------------------------------------------------------------------------- + +fn ingest_f32(array: &ArrayRef, is_atom: bool, attribute: Attribute) -> IngestionResult { + let arr = array + .as_any() + .downcast_ref::() + .expect("Float32 datatype must match Float32Array"); + + let mut values: Vec = arr.values().to_vec(); + if arr.null_count() != 0 { + for i in 0..arr.len() { + if arr.is_null(i) { + values[i] = f32::NAN; + } + } + } + + if is_atom { + Ok(Value::Atom(Atom::Real(values[0]))) + } else { + Ok(Value::Vector(Vector::new( + attribute, + VectorData::Real(vec_to_bytes(values)), + ))) + } +} + +// --------------------------------------------------------------------------- +// Float64 (Float, Datetime) +// --------------------------------------------------------------------------- + +fn ingest_f64( + array: &ArrayRef, + prim: IngestPrimitive, + is_atom: bool, + attribute: Attribute, +) -> IngestionResult { + let arr = array + .as_any() + .downcast_ref::() + .expect("Float64 datatype must match Float64Array"); + + let mut values: Vec = arr.values().to_vec(); + if arr.null_count() != 0 { + for i in 0..arr.len() { + if arr.is_null(i) { + values[i] = f64::NAN; + } + } + } + + if is_atom { + let v = values[0]; + let atom = match prim { + IngestPrimitive::Datetime => Atom::Datetime(v), + _ => Atom::Float(v), + }; + Ok(Value::Atom(atom)) + } else { + let bytes = vec_to_bytes(values); + let data = match prim { + IngestPrimitive::Datetime => VectorData::Datetime(bytes), + _ => VectorData::Float(bytes), + }; + Ok(Value::Vector(Vector::new(attribute, data))) + } +} + +// --------------------------------------------------------------------------- 
+// FixedSizeBinary(1) – Char or Byte
+// ---------------------------------------------------------------------------
+
+/// Ingest a null-free `FixedSizeBinary(1)` array as q byte (when `prim` is
+/// `IngestPrimitive::Byte`) or q char (any other primitive).
+fn ingest_fixed_binary_1(
+    array: &ArrayRef,
+    prim: IngestPrimitive,
+    is_atom: bool,
+    attribute: Attribute,
+) -> Result<Value, IngestionError> {
+    let arr = array
+        .as_any()
+        .downcast_ref::<arrow_array::FixedSizeBinaryArray>()
+        .expect("FixedSizeBinary(1) datatype must match FixedSizeBinaryArray");
+
+    if arr.null_count() != 0 {
+        return Err(IngestionError::Unsupported(
+            "Arrow FixedSizeBinary(1) arrays with nulls cannot be ingested as q char/byte"
+                .to_string(),
+        ));
+    }
+
+    // Every element is exactly one byte wide; keep that byte.
+    let values: Vec<u8> = (0..arr.len()).map(|row| arr.value(row)[0]).collect();
+    let as_byte = matches!(prim, IngestPrimitive::Byte);
+    Ok(match (as_byte, is_atom) {
+        (true, true) => Value::Atom(Atom::Byte(values[0])),
+        (true, false) => Value::Vector(Vector::new(
+            attribute,
+            VectorData::Byte(bytes::Bytes::from(values)),
+        )),
+        (false, true) => Value::Atom(Atom::Char(values[0])),
+        (false, false) => Value::Vector(Vector::new(
+            attribute,
+            VectorData::Char(bytes::Bytes::from(values)),
+        )),
+    })
+}
+
+// ---------------------------------------------------------------------------
+// FixedSizeBinary(16) – Guid
+// ---------------------------------------------------------------------------
+
+/// Ingest a null-free `FixedSizeBinary(16)` array as a q guid atom or vector.
+fn ingest_fixed_binary_16(
+    array: &ArrayRef,
+    is_atom: bool,
+    attribute: Attribute,
+) -> Result<Value, IngestionError> {
+    let arr = array
+        .as_any()
+        .downcast_ref::<arrow_array::FixedSizeBinaryArray>()
+        .expect("FixedSizeBinary(16) datatype must match FixedSizeBinaryArray");
+
+    if arr.null_count() != 0 {
+        return Err(IngestionError::Unsupported(
+            "Arrow FixedSizeBinary(16) arrays with nulls cannot be ingested as q guid".to_string(),
+        ));
+    }
+
+    let mut guids: Vec<[u8; 16]> = Vec::with_capacity(arr.len());
+    for row in 0..arr.len() {
+        let mut guid = [0u8; 16];
+        guid.copy_from_slice(arr.value(row));
+        guids.push(guid);
+    }
+
+    Ok(if is_atom {
+        Value::Atom(Atom::Guid(guids[0]))
+    } else {
+        Value::Vector(Vector::new(attribute, VectorData::from_guids(&guids)))
+    })
+}
+
+//
--------------------------------------------------------------------------- +// Symbol (various string types) +// --------------------------------------------------------------------------- + +fn strings_to_symbol_value(strings: Vec>, is_atom: bool, attribute: Attribute) -> Value { + if is_atom { + Value::Atom(Atom::Symbol(bytes::Bytes::from( + strings.into_iter().next().unwrap_or_default(), + ))) + } else { + Value::Vector(Vector::new( + attribute, + VectorData::Symbol(strings.into_iter().map(bytes::Bytes::from).collect()), + )) + } +} + +fn ingest_symbols_utf8( + array: &ArrayRef, + is_atom: bool, + attribute: Attribute, +) -> IngestionResult { + let arr = array + .as_any() + .downcast_ref::() + .expect("Utf8 datatype must match StringArray"); + let values: Vec> = (0..arr.len()) + .map(|i| { + if arr.is_null(i) { + vec![] + } else { + arr.value(i).as_bytes().to_vec() + } + }) + .collect(); + Ok(strings_to_symbol_value(values, is_atom, attribute)) +} + +fn ingest_symbols_large_utf8( + array: &ArrayRef, + is_atom: bool, + attribute: Attribute, +) -> IngestionResult { + let arr = array + .as_any() + .downcast_ref::() + .expect("LargeUtf8 datatype must match LargeStringArray"); + let values: Vec> = (0..arr.len()) + .map(|i| { + if arr.is_null(i) { + vec![] + } else { + arr.value(i).as_bytes().to_vec() + } + }) + .collect(); + Ok(strings_to_symbol_value(values, is_atom, attribute)) +} + +fn ingest_symbols_utf8_view( + array: &ArrayRef, + is_atom: bool, + attribute: Attribute, +) -> IngestionResult { + let arr = array + .as_any() + .downcast_ref::() + .expect("Utf8View datatype must match StringViewArray"); + let values: Vec> = (0..arr.len()) + .map(|i| { + if arr.is_null(i) { + vec![] + } else { + arr.value(i).as_bytes().to_vec() + } + }) + .collect(); + Ok(strings_to_symbol_value(values, is_atom, attribute)) +} + +fn ingest_symbols_dictionary( + array: &ArrayRef, + is_atom: bool, + attribute: Attribute, +) -> IngestionResult { + macro_rules! 
try_dict_type { + ($key_type:ty) => {{ + if let Some(dict) = array + .as_any() + .downcast_ref::>() + { + let values_arr = dict.values(); + let strings: Vec> = (0..dict.len()) + .map(|i| { + if dict.is_null(i) { + return vec![]; + } + let key_idx = dict.key(i).expect("non-null key must have value") as usize; + if let Some(s) = values_arr.as_any().downcast_ref::() { + s.value(key_idx).as_bytes().to_vec() + } else if let Some(s) = + values_arr.as_any().downcast_ref::() + { + s.value(key_idx).as_bytes().to_vec() + } else if let Some(s) = + values_arr.as_any().downcast_ref::() + { + s.value(key_idx).as_bytes().to_vec() + } else { + vec![] + } + }) + .collect(); + return Ok(strings_to_symbol_value(strings, is_atom, attribute)); + } + }}; + } + try_dict_type!(arrow_array::types::Int8Type); + try_dict_type!(arrow_array::types::Int16Type); + try_dict_type!(arrow_array::types::Int32Type); + try_dict_type!(arrow_array::types::Int64Type); + try_dict_type!(arrow_array::types::UInt8Type); + try_dict_type!(arrow_array::types::UInt16Type); + try_dict_type!(arrow_array::types::UInt32Type); + try_dict_type!(arrow_array::types::UInt64Type); + Err(IngestionError::Unsupported( + "Unsupported dictionary key type for symbol ingestion".to_string(), + )) +} + +// --------------------------------------------------------------------------- +// Binary → Char vector (single-element binary → char vector) +// --------------------------------------------------------------------------- + +fn ingest_binary_as_char( + array: &ArrayRef, + is_atom: bool, + attribute: Attribute, +) -> IngestionResult { + let arr = array + .as_any() + .downcast_ref::() + .expect("Binary datatype must match BinaryArray"); + + if arr.null_count() != 0 { + return Err(IngestionError::Unsupported( + "Arrow Binary arrays with nulls cannot be ingested as q char vectors".to_string(), + )); + } + if arr.len() != 1 { + return Err(IngestionError::Unsupported( + "Multi-element Binary arrays should use List shape for q 
ingestion".to_string(), + )); + } + + let bytes = arr.value(0).to_vec(); + if is_atom && bytes.len() == 1 { + Ok(Value::Atom(Atom::Char(bytes[0]))) + } else { + Ok(Value::Vector(Vector::new( + attribute, + VectorData::Char(bytes::Bytes::from(bytes)), + ))) + } +} + +fn ingest_large_binary_as_char( + array: &ArrayRef, + is_atom: bool, + attribute: Attribute, +) -> IngestionResult { + let arr = array + .as_any() + .downcast_ref::() + .expect("LargeBinary datatype must match LargeBinaryArray"); + + if arr.null_count() != 0 { + return Err(IngestionError::Unsupported( + "Arrow LargeBinary arrays with nulls cannot be ingested as q char vectors".to_string(), + )); + } + if arr.len() != 1 { + return Err(IngestionError::Unsupported( + "Multi-element LargeBinary arrays should use List shape for q ingestion".to_string(), + )); + } + + let bytes = arr.value(0).to_vec(); + if is_atom && bytes.len() == 1 { + Ok(Value::Atom(Atom::Char(bytes[0]))) + } else { + Ok(Value::Vector(Vector::new( + attribute, + VectorData::Char(bytes::Bytes::from(bytes)), + ))) + } +} + +fn ingest_binary_view_as_char( + array: &ArrayRef, + is_atom: bool, + attribute: Attribute, +) -> IngestionResult { + let arr = array + .as_any() + .downcast_ref::() + .expect("BinaryView datatype must match BinaryViewArray"); + + if arr.null_count() != 0 { + return Err(IngestionError::Unsupported( + "Arrow BinaryView arrays with nulls cannot be ingested as q char vectors".to_string(), + )); + } + if arr.len() != 1 { + return Err(IngestionError::Unsupported( + "Multi-element BinaryView arrays should use List shape for q ingestion".to_string(), + )); + } + + let bytes = arr.value(0).to_vec(); + if is_atom && bytes.len() == 1 { + Ok(Value::Atom(Atom::Char(bytes[0]))) + } else { + Ok(Value::Vector(Vector::new( + attribute, + VectorData::Char(bytes::Bytes::from(bytes)), + ))) + } +} + +// --------------------------------------------------------------------------- +// Date32 → q Date (days since 2000-01-01) +// 
--------------------------------------------------------------------------- + +fn ingest_date32(array: &ArrayRef, is_atom: bool, attribute: Attribute) -> IngestionResult { + let arr = array + .as_any() + .downcast_ref::() + .expect("Date32 datatype must match Date32Array"); + + let mut values: Vec = arr.values().to_vec(); + if arr.null_count() != 0 { + for i in 0..arr.len() { + if arr.is_null(i) { + values[i] = Q_NULL_DATE; + } + } + } + for v in &mut values { + if *v != Q_NULL_DATE { + *v = v.saturating_sub(DATE_OFFSET_DAYS); + } + } + + if is_atom { + Ok(Value::Atom(Atom::Date(values[0]))) + } else { + Ok(Value::Vector(Vector::new( + attribute, + VectorData::Date(vec_to_bytes(values)), + ))) + } +} + +// --------------------------------------------------------------------------- +// Timestamp → q Timestamp (ns since 2000-01-01) +// --------------------------------------------------------------------------- + +fn ingest_timestamp_ns( + array: &ArrayRef, + is_atom: bool, + attribute: Attribute, +) -> IngestionResult { + let arr = array + .as_any() + .downcast_ref::() + .expect("Timestamp(Nanosecond) must match TimestampNanosecondArray"); + + let mut values: Vec = arr.values().to_vec(); + if arr.null_count() != 0 { + for i in 0..arr.len() { + if arr.is_null(i) { + values[i] = Q_NULL_TIMESTAMP; + } + } + } + for v in &mut values { + if *v != Q_NULL_TIMESTAMP { + *v = v.saturating_sub(TIMESTAMP_OFFSET_NS); + } + } + + if is_atom { + Ok(Value::Atom(Atom::Timestamp(values[0]))) + } else { + Ok(Value::Vector(Vector::new( + attribute, + VectorData::Timestamp(vec_to_bytes(values)), + ))) + } +} + +fn ingest_timestamp_us( + array: &ArrayRef, + is_atom: bool, + attribute: Attribute, +) -> IngestionResult { + let arr = array + .as_any() + .downcast_ref::() + .expect("Timestamp(Microsecond) must match TimestampMicrosecondArray"); + + let mut values: Vec = arr.values().to_vec(); + if arr.null_count() != 0 { + for i in 0..arr.len() { + if arr.is_null(i) { + values[i] = 
Q_NULL_TIMESTAMP; + } + } + } + for v in &mut values { + if *v != Q_NULL_TIMESTAMP { + *v = v.saturating_mul(1_000).saturating_sub(TIMESTAMP_OFFSET_NS); + } + } + + if is_atom { + Ok(Value::Atom(Atom::Timestamp(values[0]))) + } else { + Ok(Value::Vector(Vector::new( + attribute, + VectorData::Timestamp(vec_to_bytes(values)), + ))) + } +} + +fn ingest_timestamp_ms( + array: &ArrayRef, + is_atom: bool, + attribute: Attribute, +) -> IngestionResult { + let arr = array + .as_any() + .downcast_ref::() + .expect("Timestamp(Millisecond) must match TimestampMillisecondArray"); + + let mut values: Vec = arr.values().to_vec(); + if arr.null_count() != 0 { + for i in 0..arr.len() { + if arr.is_null(i) { + values[i] = Q_NULL_TIMESTAMP; + } + } + } + for v in &mut values { + if *v != Q_NULL_TIMESTAMP { + *v = v + .saturating_mul(1_000_000) + .saturating_sub(TIMESTAMP_OFFSET_NS); + } + } + + if is_atom { + Ok(Value::Atom(Atom::Timestamp(values[0]))) + } else { + Ok(Value::Vector(Vector::new( + attribute, + VectorData::Timestamp(vec_to_bytes(values)), + ))) + } +} + +fn ingest_timestamp_s( + array: &ArrayRef, + is_atom: bool, + attribute: Attribute, +) -> IngestionResult { + let arr = array + .as_any() + .downcast_ref::() + .expect("Timestamp(Second) must match TimestampSecondArray"); + + let mut values: Vec = arr.values().to_vec(); + if arr.null_count() != 0 { + for i in 0..arr.len() { + if arr.is_null(i) { + values[i] = Q_NULL_TIMESTAMP; + } + } + } + for v in &mut values { + if *v != Q_NULL_TIMESTAMP { + *v = v + .saturating_mul(1_000_000_000) + .saturating_sub(TIMESTAMP_OFFSET_NS); + } + } + + if is_atom { + Ok(Value::Atom(Atom::Timestamp(values[0]))) + } else { + Ok(Value::Vector(Vector::new( + attribute, + VectorData::Timestamp(vec_to_bytes(values)), + ))) + } +} + +// --------------------------------------------------------------------------- +// Duration → q Timespan (ns) +// --------------------------------------------------------------------------- + +fn 
ingest_duration_ns( + array: &ArrayRef, + is_atom: bool, + attribute: Attribute, +) -> IngestionResult { + let arr = array + .as_any() + .downcast_ref::() + .expect("Duration(Nanosecond) must match DurationNanosecondArray"); + + let mut values: Vec = arr.values().to_vec(); + if arr.null_count() != 0 { + for i in 0..arr.len() { + if arr.is_null(i) { + values[i] = Q_NULL_TIMESPAN; + } + } + } + + if is_atom { + Ok(Value::Atom(Atom::Timespan(values[0]))) + } else { + Ok(Value::Vector(Vector::new( + attribute, + VectorData::Timespan(vec_to_bytes(values)), + ))) + } +} + +fn ingest_duration_us( + array: &ArrayRef, + is_atom: bool, + attribute: Attribute, +) -> IngestionResult { + let arr = array + .as_any() + .downcast_ref::() + .expect("Duration(Microsecond) must match DurationMicrosecondArray"); + + let mut values: Vec = arr.values().to_vec(); + if arr.null_count() != 0 { + for i in 0..arr.len() { + if arr.is_null(i) { + values[i] = Q_NULL_TIMESPAN; + } + } + } + for v in &mut values { + if *v != Q_NULL_TIMESPAN { + *v = v.saturating_mul(1_000); + } + } + + if is_atom { + Ok(Value::Atom(Atom::Timespan(values[0]))) + } else { + Ok(Value::Vector(Vector::new( + attribute, + VectorData::Timespan(vec_to_bytes(values)), + ))) + } +} + +fn ingest_duration_ms( + array: &ArrayRef, + is_atom: bool, + attribute: Attribute, +) -> IngestionResult { + let arr = array + .as_any() + .downcast_ref::() + .expect("Duration(Millisecond) must match DurationMillisecondArray"); + + let mut values: Vec = arr.values().to_vec(); + if arr.null_count() != 0 { + for i in 0..arr.len() { + if arr.is_null(i) { + values[i] = Q_NULL_TIMESPAN; + } + } + } + for v in &mut values { + if *v != Q_NULL_TIMESPAN { + *v = v.saturating_mul(1_000_000); + } + } + + if is_atom { + Ok(Value::Atom(Atom::Timespan(values[0]))) + } else { + Ok(Value::Vector(Vector::new( + attribute, + VectorData::Timespan(vec_to_bytes(values)), + ))) + } +} + +fn ingest_duration_s( + array: &ArrayRef, + is_atom: bool, + attribute: 
Attribute, +) -> IngestionResult { + let arr = array + .as_any() + .downcast_ref::() + .expect("Duration(Second) must match DurationSecondArray"); + + let mut values: Vec = arr.values().to_vec(); + if arr.null_count() != 0 { + for i in 0..arr.len() { + if arr.is_null(i) { + values[i] = Q_NULL_TIMESPAN; + } + } + } + for v in &mut values { + if *v != Q_NULL_TIMESPAN { + *v = v.saturating_mul(1_000_000_000); + } + } + + if is_atom { + Ok(Value::Atom(Atom::Timespan(values[0]))) + } else { + Ok(Value::Vector(Vector::new( + attribute, + VectorData::Timespan(vec_to_bytes(values)), + ))) + } +} + +// --------------------------------------------------------------------------- +// Time32(Second) → q Second or Minute +// --------------------------------------------------------------------------- + +fn ingest_time32_second( + array: &ArrayRef, + prim: IngestPrimitive, + is_atom: bool, + attribute: Attribute, +) -> IngestionResult { + let arr = array + .as_any() + .downcast_ref::() + .expect("Time32(Second) must match Time32SecondArray"); + + let mut values: Vec = arr.values().to_vec(); + if prim == IngestPrimitive::Minute { + let null = Q_NULL_MINUTE; + if arr.null_count() != 0 { + for i in 0..arr.len() { + if arr.is_null(i) { + values[i] = null; + } + } + } + for v in &mut values { + if *v != null { + *v /= 60; + } + } + if is_atom { + Ok(Value::Atom(Atom::Minute(values[0]))) + } else { + Ok(Value::Vector(Vector::new( + attribute, + VectorData::Minute(vec_to_bytes(values)), + ))) + } + } else { + let null = Q_NULL_SECOND; + if arr.null_count() != 0 { + for i in 0..arr.len() { + if arr.is_null(i) { + values[i] = null; + } + } + } + if is_atom { + Ok(Value::Atom(Atom::Second(values[0]))) + } else { + Ok(Value::Vector(Vector::new( + attribute, + VectorData::Second(vec_to_bytes(values)), + ))) + } + } +} + +// --------------------------------------------------------------------------- +// Time32(Millisecond) → q Time (ms) +// 
--------------------------------------------------------------------------- + +fn ingest_time32_ms( + array: &ArrayRef, + is_atom: bool, + attribute: Attribute, +) -> IngestionResult { + let arr = array + .as_any() + .downcast_ref::() + .expect("Time32(Millisecond) must match Time32MillisecondArray"); + + let mut values: Vec = arr.values().to_vec(); + if arr.null_count() != 0 { + for i in 0..arr.len() { + if arr.is_null(i) { + values[i] = Q_NULL_TIME; + } + } + } + + if is_atom { + Ok(Value::Atom(Atom::Time(values[0]))) + } else { + Ok(Value::Vector(Vector::new( + attribute, + VectorData::Time(vec_to_bytes(values)), + ))) + } +} + +// --------------------------------------------------------------------------- +// Time64(Microsecond) → q Time (ms, truncating) +// --------------------------------------------------------------------------- + +fn ingest_time64_us( + array: &ArrayRef, + is_atom: bool, + attribute: Attribute, +) -> IngestionResult { + let arr = array + .as_any() + .downcast_ref::() + .expect("Time64(Microsecond) must match Time64MicrosecondArray"); + + let values: Vec = (0..arr.len()) + .map(|i| { + if arr.is_null(i) { + Q_NULL_TIME + } else { + (arr.value(i) / 1_000).clamp(i64::from(i32::MIN), i64::from(i32::MAX)) as i32 + } + }) + .collect(); + + if is_atom { + Ok(Value::Atom(Atom::Time(values[0]))) + } else { + Ok(Value::Vector(Vector::new( + attribute, + VectorData::Time(vec_to_bytes(values)), + ))) + } +} + +// --------------------------------------------------------------------------- +// Time64(Nanosecond) → q Time (ms, truncating) +// --------------------------------------------------------------------------- + +fn ingest_time64_ns( + array: &ArrayRef, + is_atom: bool, + attribute: Attribute, +) -> IngestionResult { + let arr = array + .as_any() + .downcast_ref::() + .expect("Time64(Nanosecond) must match Time64NanosecondArray"); + + let values: Vec = (0..arr.len()) + .map(|i| { + if arr.is_null(i) { + Q_NULL_TIME + } else { + (arr.value(i) / 
1_000_000).clamp(i64::from(i32::MIN), i64::from(i32::MAX)) as i32 + } + }) + .collect(); + + if is_atom { + Ok(Value::Atom(Atom::Time(values[0]))) + } else { + Ok(Value::Vector(Vector::new( + attribute, + VectorData::Time(vec_to_bytes(values)), + ))) + } +} diff --git a/crates/qroissant-arrow/src/lib.rs b/crates/qroissant-arrow/src/lib.rs new file mode 100644 index 0000000..d4e617a --- /dev/null +++ b/crates/qroissant-arrow/src/lib.rs @@ -0,0 +1,26 @@ +//! Arrow interop layer for qroissant. +//! +//! Converts decoded q `Value` trees (from `qroissant-core`) into Apache Arrow +//! arrays and record batches. PyO3 and PyCapsule handling live in +//! `qroissant-python`; this crate is intentionally free of Python dependencies. + +pub mod error; +pub mod ingestion; +pub mod metadata; +pub mod options; +pub mod projection; + +pub use error::IngestionError; +pub use ingestion::ingest_array; +pub use ingestion::ingest_record_batch; +pub use ingestion::ingest_record_batch_reader; +pub use options::ListProjection; +pub use options::ProjectionOptions; +pub use options::StringProjection; +pub use options::SymbolProjection; +pub use options::UnionMode; +pub use projection::ArrayExport; +pub use projection::BatchExport; +pub use projection::project; +pub use projection::project_table; +pub use qroissant_core::HEADER_LEN as QIPC_HEADER_LEN; diff --git a/crates/qroissant-arrow/src/metadata.rs b/crates/qroissant-arrow/src/metadata.rs new file mode 100644 index 0000000..41b7ed3 --- /dev/null +++ b/crates/qroissant-arrow/src/metadata.rs @@ -0,0 +1,90 @@ +//! Arrow field metadata helpers for preserving q type semantics. +//! +//! Every Arrow field produced by qroissant carries metadata that round-trips +//! the original q shape, primitive type, and attribute information so that +//! downstream consumers can reconstruct exact q semantics from an Arrow schema. 
+
+use std::collections::HashMap;
+use std::sync::Arc;
+
+use arrow_schema::Field;
+use arrow_schema::FieldRef;
+use qroissant_core::Attribute;
+use qroissant_core::Primitive;
+
+/// Metadata key recording the q shape (`"atom"`, `"vector"`, `"table"`, ...).
+pub const SHAPE_KEY: &str = "qroissant.shape";
+/// Metadata key recording the q primitive type name.
+pub const PRIMITIVE_KEY: &str = "qroissant.primitive";
+/// Metadata key recording the q attribute name.
+pub const ATTRIBUTE_KEY: &str = "qroissant.attribute";
+/// Metadata key recording the optional `sorted` flag passed to `q_field`.
+pub const SORTED_KEY: &str = "qroissant.sorted";
+
+/// Returns the metadata spelling of a q shape name; every name (known or
+/// not) maps to itself, so this is the identity.
+fn shape_to_str(shape: &str) -> &str {
+    shape
+}
+
+/// Returns the lower-case metadata name of a q primitive type.
+fn primitive_str(p: Primitive) -> &'static str {
+    match p {
+        Primitive::Boolean => "boolean",
+        Primitive::Guid => "guid",
+        Primitive::Byte => "byte",
+        Primitive::Short => "short",
+        Primitive::Int => "int",
+        Primitive::Long => "long",
+        Primitive::Real => "real",
+        Primitive::Float => "float",
+        Primitive::Char => "char",
+        Primitive::Symbol => "symbol",
+        Primitive::Timestamp => "timestamp",
+        Primitive::Month => "month",
+        Primitive::Date => "date",
+        Primitive::Datetime => "datetime",
+        Primitive::Timespan => "timespan",
+        Primitive::Minute => "minute",
+        Primitive::Second => "second",
+        Primitive::Time => "time",
+        Primitive::Mixed => "mixed",
+    }
+}
+
+/// Returns the lower-case metadata name of a q attribute.
+fn attribute_str(a: Attribute) -> &'static str {
+    match a {
+        Attribute::None => "none",
+        Attribute::Sorted => "sorted",
+        Attribute::Unique => "unique",
+        Attribute::Parted => "parted",
+        Attribute::Grouped => "grouped",
+    }
+}
+
+/// Build an Arrow field for a q atom or vector column.
+///
+/// The field name is left empty (`""`); callers that embed the field in a
+/// schema or struct field should rename it via [`arrow_schema::Field::with_name`].
+pub fn q_field( + data_type: arrow_schema::DataType, + nullable: bool, + shape: &str, + primitive: Option, + attribute: Option, + sorted: Option, +) -> FieldRef { + let mut meta = HashMap::new(); + meta.insert(SHAPE_KEY.to_string(), shape_to_str(shape).to_string()); + if let Some(p) = primitive { + meta.insert(PRIMITIVE_KEY.to_string(), primitive_str(p).to_string()); + } + if let Some(a) = attribute { + meta.insert(ATTRIBUTE_KEY.to_string(), attribute_str(a).to_string()); + } + if let Some(s) = sorted { + meta.insert(SORTED_KEY.to_string(), s.to_string()); + } + Arc::new(Field::new("", data_type, nullable).with_metadata(meta)) +} diff --git a/crates/qroissant-arrow/src/options.rs b/crates/qroissant-arrow/src/options.rs new file mode 100644 index 0000000..8938919 --- /dev/null +++ b/crates/qroissant-arrow/src/options.rs @@ -0,0 +1,85 @@ +//! Projection-level configuration for `qroissant-arrow`. +//! +//! This module is intentionally free of PyO3 so the arrow crate can remain +//! Python-agnostic. The Python crate converts `DecodeOptions` into a +//! `ProjectionOptions` at decode time and stores it alongside the value. + +/// How to project q symbol (byte-string) values into Arrow. +#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)] +pub enum SymbolProjection { + /// Arrow `Utf8` / `StringArray`. Lossily converts non-UTF-8 bytes. + #[default] + Utf8, + /// Arrow `LargeUtf8` / `LargeStringArray`. + LargeUtf8, + /// Arrow `Utf8View` / `StringViewArray`. + Utf8View, + /// Arrow `Dictionary`. + Dictionary, + /// Arrow `Binary` / `BinaryArray` — raw bytes, no UTF-8 coercion. + RawBytes, +} + +/// How to project q char-vector (byte string) values into Arrow. +#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)] +pub enum StringProjection { + /// Arrow `Utf8` / `StringArray` (best-effort UTF-8). + #[default] + Utf8, + /// Arrow `Binary` / `BinaryArray` — raw bytes. + Binary, +} + +/// Wrapper Arrow type used for homogeneous q list projection. 
+#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)] +pub enum ListProjection { + /// Arrow `List` / `ListArray`. + List, + /// Arrow `LargeList` / `LargeListArray`. + #[default] + LargeList, + /// Arrow `ListView` — not yet supported; falls back to `List`. + ListView, +} + +/// Union encoding for heterogeneous q list projection. +#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)] +pub enum UnionMode { + /// Arrow dense union (compact offsets). + #[default] + Dense, + /// Arrow sparse union (one slot per item per type). + Sparse, +} + +/// Combined projection options threaded through `project()` / `project_table()`. +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct ProjectionOptions { + pub symbol: SymbolProjection, + pub string: StringProjection, + pub list: ListProjection, + pub union_mode: UnionMode, + /// When `true`, q infinity sentinels (e.g. `0Wi`, `0Wj`, `0w`) are mapped + /// to Arrow nulls alongside the standard null sentinels. Default: `false`. + pub treat_infinity_as_null: bool, + /// When `true` and the table has at least 4 columns, column projection + /// is performed in parallel using rayon. Default: `true`. + pub parallel: bool, + /// When `true`, symbol bytes are assumed to be valid UTF-8 and are + /// reinterpreted without validation or allocation. Default: `true`. + pub assume_symbol_utf8: bool, +} + +impl Default for ProjectionOptions { + fn default() -> Self { + Self { + symbol: SymbolProjection::default(), + string: StringProjection::default(), + list: ListProjection::default(), + union_mode: UnionMode::default(), + treat_infinity_as_null: false, + parallel: true, + assume_symbol_utf8: true, + } + } +} diff --git a/crates/qroissant-arrow/src/projection.rs b/crates/qroissant-arrow/src/projection.rs new file mode 100644 index 0000000..6219b9e --- /dev/null +++ b/crates/qroissant-arrow/src/projection.rs @@ -0,0 +1,1160 @@ +//! Arrow projection: converts decoded q `Value` trees into Arrow arrays. 
+ +use std::mem::size_of; +use std::ptr::NonNull; +use std::sync::Arc; + +use arrow_array::ArrayRef; +use arrow_array::BinaryArray; +use arrow_array::BooleanArray; +use arrow_buffer::BooleanBuffer; +use arrow_array::Date32Array; +use arrow_array::DurationNanosecondArray; +use arrow_array::FixedSizeBinaryArray; +use arrow_array::Float32Array; +use arrow_array::Float64Array; +use arrow_array::Int16Array; +use arrow_array::Int32Array; +use arrow_array::Int64Array; +use arrow_array::LargeListArray; +use arrow_array::LargeStringArray; +use arrow_array::ListArray; +use arrow_array::RecordBatch; +use arrow_array::StringArray; +use arrow_array::StringViewArray; +use arrow_array::StructArray; +use arrow_array::Time32MillisecondArray; +use arrow_array::Time32SecondArray; +use arrow_array::TimestampMillisecondArray; +use arrow_array::TimestampNanosecondArray; +use arrow_array::UInt8Array; +use arrow_array::UnionArray; +use arrow_array::builder::StringDictionaryBuilder; +use arrow_array::types::Int32Type; +use arrow_buffer::Buffer; +use arrow_buffer::NullBuffer; +use arrow_buffer::OffsetBuffer; +use arrow_buffer::ScalarBuffer; +use arrow_schema::DataType; +use arrow_schema::Field; +use arrow_schema::FieldRef; +use arrow_schema::Fields; +use arrow_schema::Schema; +use arrow_schema::SchemaRef; +use arrow_schema::TimeUnit; +use arrow_schema::UnionFields; +use arrow_schema::UnionMode as ArrowUnionMode; +use chrono::Months; +use chrono::NaiveDate; +use qroissant_core::Atom; +use qroissant_core::Attribute; +use qroissant_core::Dictionary; +use qroissant_core::List; +use qroissant_core::Table; +use qroissant_core::Value; +use qroissant_core::Vector; +use qroissant_core::VectorData; +use qroissant_kernels::nulls::Q_INF_DATE; +use qroissant_kernels::nulls::Q_INF_INT; +use qroissant_kernels::nulls::Q_INF_LONG; +use qroissant_kernels::nulls::Q_INF_MINUTE; +use qroissant_kernels::nulls::Q_INF_SECOND; +use qroissant_kernels::nulls::Q_INF_SHORT; +use qroissant_kernels::nulls::Q_INF_TIME; 
+use qroissant_kernels::nulls::Q_INF_TIMESPAN; +use qroissant_kernels::nulls::Q_INF_TIMESTAMP; +use qroissant_kernels::nulls::Q_NINF_DATE; +use qroissant_kernels::nulls::Q_NINF_INT; +use qroissant_kernels::nulls::Q_NINF_LONG; +use qroissant_kernels::nulls::Q_NINF_MINUTE; +use qroissant_kernels::nulls::Q_NINF_SECOND; +use qroissant_kernels::nulls::Q_NINF_SHORT; +use qroissant_kernels::nulls::Q_NINF_TIME; +use qroissant_kernels::nulls::Q_NINF_TIMESPAN; +use qroissant_kernels::nulls::Q_NINF_TIMESTAMP; +use qroissant_kernels::nulls::Q_NULL_DATE; +use qroissant_kernels::nulls::Q_NULL_INT; +use qroissant_kernels::nulls::Q_NULL_LONG; +use qroissant_kernels::nulls::Q_NULL_MINUTE; +use qroissant_kernels::nulls::Q_NULL_MONTH; +use qroissant_kernels::nulls::Q_NULL_SECOND; +use qroissant_kernels::nulls::Q_NULL_SHORT; +use qroissant_kernels::nulls::Q_NULL_TIME; +use qroissant_kernels::nulls::Q_NULL_TIMESPAN; +use qroissant_kernels::nulls::Q_NULL_TIMESTAMP; +use qroissant_kernels::nulls::validity_f32; +use qroissant_kernels::nulls::validity_f64; +use qroissant_kernels::nulls::validity_i16; +use qroissant_kernels::nulls::validity_i32; +use qroissant_kernels::nulls::validity_i64; +use qroissant_kernels::temporal::DATE_OFFSET_DAYS; +use qroissant_kernels::temporal::MILLIS_PER_DAY; +use qroissant_kernels::temporal::TIMESTAMP_OFFSET_NS; +use qroissant_kernels::boolean::pack_bool_bytes; +use qroissant_kernels::temporal::copy_and_minutes_to_seconds; +use qroissant_kernels::temporal::copy_and_offset_dates; +use qroissant_kernels::temporal::copy_and_offset_timestamps; +use rayon::prelude::*; + +use crate::error::ProjectionError; +use crate::error::ProjectionResult; +use crate::metadata::q_field; +use crate::options::ListProjection; +use crate::options::ProjectionOptions; +use crate::options::StringProjection; +use crate::options::SymbolProjection; +use crate::options::UnionMode; + +const MILLENNIUM: NaiveDate = NaiveDate::from_ymd_opt(2000, 1, 1).unwrap(); +const UNIX_EPOCH: NaiveDate = 
NaiveDate::from_ymd_opt(1970, 1, 1).unwrap(); + +// --------------------------------------------------------------------------- +// Public output types +// --------------------------------------------------------------------------- + +/// Projected Arrow array together with its annotated field descriptor. +pub struct ArrayExport { + pub array: ArrayRef, + pub field: FieldRef, +} + +/// Projected Arrow record batch together with the struct-typed wrapper needed +/// for the C stream interface. +pub struct BatchExport { + /// The schema of the record batch. + pub schema: SchemaRef, + /// The data as an owned RecordBatch. + pub batch: RecordBatch, + /// The StructArray version of `batch`, ready to feed into an Arrow stream. + pub struct_array: ArrayRef, + /// A struct field that wraps the schema so pyo3-arrow can create a stream + /// capsule without needing to inspect Arrow internals. + pub struct_field: FieldRef, +} + +// --------------------------------------------------------------------------- +// Entry points +// --------------------------------------------------------------------------- + +/// Project any `Value` (except Table) into an `ArrayExport`. +/// +/// For atoms the result is a length-1 array. For vectors the result is a +/// full-length array. Tables must be projected with [`project_table`]. +pub fn project(value: &Value, opts: &ProjectionOptions) -> ProjectionResult { + match value { + Value::Atom(atom) => project_atom(atom, opts), + Value::Vector(vector) => project_vector(vector, opts), + Value::List(list) => project_list(list, opts), + Value::Dictionary(dict) => project_dictionary(dict, opts), + Value::Table(table) => { + let batch = project_table(table, opts)?; + Ok(ArrayExport { + array: batch.struct_array, + field: batch.struct_field, + }) + } + Value::UnaryPrimitive { .. 
} => { + let array: ArrayRef = Arc::new(arrow_array::NullArray::new(1)); + let field = q_field(DataType::Null, true, "unary_primitive", None, None, None); + Ok(ArrayExport { array, field }) + } + } +} + +/// Project a q `Table` into a `BatchExport` suitable for the Arrow C stream +/// interface. +pub fn project_table(table: &Table, opts: &ProjectionOptions) -> ProjectionResult { + const PARALLEL_THRESHOLD: usize = 4; + let use_parallel = opts.parallel && table.num_columns() >= PARALLEL_THRESHOLD; + + let column_pairs: Vec<(&[u8], &qroissant_core::Value)> = table + .column_names() + .iter() + .map(|n| n.as_ref()) + .zip(table.columns()) + .collect(); + + let results: Vec> = if use_parallel { + column_pairs + .par_iter() + .map(|(name_bytes, column)| { + let name = String::from_utf8_lossy(name_bytes); + let export = project(column, opts)?; + let named_field = Arc::new(export.field.as_ref().clone().with_name(name.as_ref())); + Ok((named_field, export.array)) + }) + .collect() + } else { + column_pairs + .iter() + .map(|(name_bytes, column)| { + let name = String::from_utf8_lossy(name_bytes); + let export = project(column, opts)?; + let named_field = Arc::new(export.field.as_ref().clone().with_name(name.as_ref())); + Ok((named_field, export.array)) + }) + .collect() + }; + + let mut fields: Vec = Vec::with_capacity(table.num_columns()); + let mut arrays: Vec = Vec::with_capacity(table.num_columns()); + for result in results { + let (field, array) = result?; + fields.push(field); + arrays.push(array); + } + + // Build schema metadata marking this as a q table. 
+ let mut meta = std::collections::HashMap::new(); + meta.insert(crate::metadata::SHAPE_KEY.to_string(), "table".to_string()); + if table.attribute() != Attribute::None { + meta.insert( + crate::metadata::ATTRIBUTE_KEY.to_string(), + attribute_meta_str(table.attribute()).to_string(), + ); + } + let schema = Arc::new(Schema::new(fields.clone()).with_metadata(meta)); + + let batch = RecordBatch::try_new(schema.clone(), arrays) + .map_err(|e| ProjectionError::Arrow(e.to_string()))?; + + // Wrap as StructArray for the C stream interface. + let struct_field: FieldRef = Arc::new( + Field::new_struct("", schema.fields().clone(), false) + .with_metadata(schema.metadata().clone()), + ); + let struct_array: ArrayRef = Arc::new(arrow_array::StructArray::from(batch.clone())); + + Ok(BatchExport { + schema, + batch, + struct_array, + struct_field, + }) +} + +// --------------------------------------------------------------------------- +// Atom projection +// --------------------------------------------------------------------------- + +fn project_atom(atom: &Atom, opts: &ProjectionOptions) -> ProjectionResult { + let primitive = atom.primitive(); + let (array, dt) = match atom { + Atom::Boolean(v) => { + let arr: ArrayRef = Arc::new(BooleanArray::from(vec![*v])); + (arr, DataType::Boolean) + } + Atom::Guid(v) => { + let arr: ArrayRef = Arc::new( + FixedSizeBinaryArray::try_from_iter(std::iter::once(v.as_slice())) + .map_err(|e| ProjectionError::Arrow(e.to_string()))?, + ); + (arr, DataType::FixedSizeBinary(16)) + } + Atom::Byte(v) => { + let arr: ArrayRef = Arc::new(UInt8Array::from(vec![*v])); + (arr, DataType::UInt8) + } + Atom::Short(v) => { + let opt = if *v == Q_NULL_SHORT { None } else { Some(*v) }; + let arr: ArrayRef = Arc::new(Int16Array::from(vec![opt])); + (arr, DataType::Int16) + } + Atom::Int(v) => { + let opt = if *v == Q_NULL_INT { None } else { Some(*v) }; + let arr: ArrayRef = Arc::new(Int32Array::from(vec![opt])); + (arr, DataType::Int32) + } + Atom::Long(v) 
=> { + let opt = if *v == Q_NULL_LONG { None } else { Some(*v) }; + let arr: ArrayRef = Arc::new(Int64Array::from(vec![opt])); + (arr, DataType::Int64) + } + Atom::Real(v) => { + let opt = if v.is_nan() { None } else { Some(*v) }; + let arr: ArrayRef = Arc::new(Float32Array::from(vec![opt])); + (arr, DataType::Float32) + } + Atom::Float(v) => { + let opt = if v.is_nan() { None } else { Some(*v) }; + let arr: ArrayRef = Arc::new(Float64Array::from(vec![opt])); + (arr, DataType::Float64) + } + Atom::Char(v) => { + let (arr, dt) = project_char_atom(*v, opts.string); + return Ok(ArrayExport { + array: arr, + field: q_field(dt, false, "atom", Some(primitive), None, None), + }); + } + Atom::Symbol(v) => { + let (arr, dt) = project_symbol_bytes( + std::slice::from_ref(v), + opts.symbol, + opts.assume_symbol_utf8, + )?; + return Ok(ArrayExport { + array: arr, + field: q_field(dt, false, "atom", Some(primitive), None, None), + }); + } + Atom::Timestamp(v) => { + let opt = if *v == Q_NULL_TIMESTAMP { + None + } else { + Some(v.saturating_add(TIMESTAMP_OFFSET_NS)) + }; + let arr: ArrayRef = Arc::new(TimestampNanosecondArray::from(vec![opt])); + (arr, DataType::Timestamp(TimeUnit::Nanosecond, None)) + } + Atom::Month(v) => { + let opt = q_month_to_date32(*v); + let arr: ArrayRef = Arc::new(Date32Array::from(vec![opt])); + (arr, DataType::Date32) + } + Atom::Date(v) => { + let opt = if *v == Q_NULL_DATE { + None + } else { + Some(v.saturating_add(DATE_OFFSET_DAYS)) + }; + let arr: ArrayRef = Arc::new(Date32Array::from(vec![opt])); + (arr, DataType::Date32) + } + Atom::Datetime(v) => { + let opt = if v.is_nan() { + None + } else { + Some(datetime_to_ms(*v)) + }; + let arr: ArrayRef = Arc::new(TimestampMillisecondArray::from(vec![opt])); + (arr, DataType::Timestamp(TimeUnit::Millisecond, None)) + } + Atom::Timespan(v) => { + let opt = if *v == Q_NULL_TIMESPAN { + None + } else { + Some(*v) + }; + let arr: ArrayRef = Arc::new(DurationNanosecondArray::from(vec![opt])); + (arr, 
DataType::Duration(TimeUnit::Nanosecond)) + } + Atom::Minute(v) => { + let opt = if *v == Q_NULL_MINUTE { + None + } else { + Some(v.saturating_mul(60)) + }; + let arr: ArrayRef = Arc::new(Time32SecondArray::from(vec![opt])); + (arr, DataType::Time32(TimeUnit::Second)) + } + Atom::Second(v) => { + let opt = if *v == Q_NULL_SECOND { None } else { Some(*v) }; + let arr: ArrayRef = Arc::new(Time32SecondArray::from(vec![opt])); + (arr, DataType::Time32(TimeUnit::Second)) + } + Atom::Time(v) => { + let opt = if *v == Q_NULL_TIME { None } else { Some(*v) }; + let arr: ArrayRef = Arc::new(Time32MillisecondArray::from(vec![opt])); + (arr, DataType::Time32(TimeUnit::Millisecond)) + } + }; + + let nullable = array.null_count() > 0; + let field = q_field(dt, nullable, "atom", Some(primitive), None, None); + Ok(ArrayExport { array, field }) +} + +// --------------------------------------------------------------------------- +// Vector projection +// --------------------------------------------------------------------------- + +/// Converts an optional SIMD validity vector into an Arrow `NullBuffer`. +fn to_null_buffer(validity: Option>) -> Option { + validity.map(|v| NullBuffer::from(v)) +} + +/// Creates an Arrow `Buffer` backed by a `bytes::Bytes` without copying. +/// +/// The `Bytes` reference count is bumped (via `Arc`) to keep the memory alive +/// as long as Arrow holds the buffer. When Arrow drops the buffer, it drops +/// the `Arc`, which decrements the refcount. +fn bytes_to_arrow_buffer(bytes: bytes::Bytes) -> Buffer { + let len = bytes.len(); + if len == 0 { + return Buffer::from_vec(Vec::::new()); + } + let ptr = NonNull::new(bytes.as_ptr() as *mut u8).expect("non-null Bytes pointer"); + // SAFETY: `ptr` is valid for `len` bytes, and the `Arc` moved + // into the allocation owner keeps the backing memory alive. + unsafe { Buffer::from_custom_allocation(ptr, len, Arc::new(bytes)) } +} + +/// Merges a null-only validity with additional infinity sentinels. 
+/// Used when `treat_infinity_as_null` is enabled. +fn merge_infinity_i16(values: &[i16], null_validity: Option>) -> Option> { + let has_inf = values + .iter() + .any(|&v| v == Q_INF_SHORT || v == Q_NINF_SHORT); + if !has_inf { + return null_validity; + } + let mut validity = null_validity.unwrap_or_else(|| vec![true; values.len()]); + for (i, &v) in values.iter().enumerate() { + if v == Q_INF_SHORT || v == Q_NINF_SHORT { + validity[i] = false; + } + } + Some(validity) +} + +fn merge_infinity_i32( + values: &[i32], + sentinel_inf: i32, + sentinel_ninf: i32, + null_validity: Option>, +) -> Option> { + let has_inf = values + .iter() + .any(|&v| v == sentinel_inf || v == sentinel_ninf); + if !has_inf { + return null_validity; + } + let mut validity = null_validity.unwrap_or_else(|| vec![true; values.len()]); + for (i, &v) in values.iter().enumerate() { + if v == sentinel_inf || v == sentinel_ninf { + validity[i] = false; + } + } + Some(validity) +} + +fn merge_infinity_i64( + values: &[i64], + sentinel_inf: i64, + sentinel_ninf: i64, + null_validity: Option>, +) -> Option> { + let has_inf = values + .iter() + .any(|&v| v == sentinel_inf || v == sentinel_ninf); + if !has_inf { + return null_validity; + } + let mut validity = null_validity.unwrap_or_else(|| vec![true; values.len()]); + for (i, &v) in values.iter().enumerate() { + if v == sentinel_inf || v == sentinel_ninf { + validity[i] = false; + } + } + Some(validity) +} + +fn merge_infinity_f32(values: &[f32], null_validity: Option>) -> Option> { + let has_inf = values.iter().any(|v| v.is_infinite()); + if !has_inf { + return null_validity; + } + let mut validity = null_validity.unwrap_or_else(|| vec![true; values.len()]); + for (i, v) in values.iter().enumerate() { + if v.is_infinite() { + validity[i] = false; + } + } + Some(validity) +} + +fn merge_infinity_f64(values: &[f64], null_validity: Option>) -> Option> { + let has_inf = values.iter().any(|v| v.is_infinite()); + if !has_inf { + return null_validity; + } + 
let mut validity = null_validity.unwrap_or_else(|| vec![true; values.len()]); + for (i, v) in values.iter().enumerate() { + if v.is_infinite() { + validity[i] = false; + } + } + Some(validity) +} + +fn project_vector(vector: &Vector, opts: &ProjectionOptions) -> ProjectionResult { + let attribute = vector.attribute(); + let primitive = vector.primitive(); + let attr_opt = if attribute == Attribute::None { + None + } else { + Some(attribute) + }; + + let data = vector.data(); + let (array, dt): (ArrayRef, DataType) = match data { + VectorData::Boolean(values) => { + let (bitmap, len) = pack_bool_bytes(values.as_ref()); + let buf = Buffer::from_vec(bitmap); + let arr: ArrayRef = + Arc::new(BooleanArray::new(BooleanBuffer::new(buf, 0, len), None)); + (arr, DataType::Boolean) + } + VectorData::Guid(raw) => { + let len = raw.len() / 16; + let buf = bytes_to_arrow_buffer(raw.clone()); + let arr: ArrayRef = Arc::new(FixedSizeBinaryArray::new(16, buf, None)); + debug_assert_eq!(arr.len(), len); + (arr, DataType::FixedSizeBinary(16)) + } + VectorData::Byte(raw) => { + let len = raw.len(); + let buf = bytes_to_arrow_buffer(raw.clone()); + let arr: ArrayRef = Arc::new(UInt8Array::new(ScalarBuffer::new(buf, 0, len), None)); + (arr, DataType::UInt8) + } + VectorData::Short(raw) => { + let values = data.as_i16_slice(); + let mut validity = validity_i16(values); + if opts.treat_infinity_as_null { + validity = merge_infinity_i16(values, validity); + } + let nulls = to_null_buffer(validity); + let buf = bytes_to_arrow_buffer(raw.clone()); + let arr: ArrayRef = Arc::new(Int16Array::new( + ScalarBuffer::new(buf, 0, values.len()), + nulls, + )); + (arr, DataType::Int16) + } + VectorData::Int(raw) => { + let values = data.as_i32_slice(); + let mut validity = validity_i32(values, Q_NULL_INT); + if opts.treat_infinity_as_null { + validity = merge_infinity_i32(values, Q_INF_INT, Q_NINF_INT, validity); + } + let nulls = to_null_buffer(validity); + let buf = 
bytes_to_arrow_buffer(raw.clone()); + let arr: ArrayRef = Arc::new(Int32Array::new( + ScalarBuffer::new(buf, 0, values.len()), + nulls, + )); + (arr, DataType::Int32) + } + VectorData::Long(raw) => { + let values = data.as_i64_slice(); + let mut validity = validity_i64(values, Q_NULL_LONG); + if opts.treat_infinity_as_null { + validity = merge_infinity_i64(values, Q_INF_LONG, Q_NINF_LONG, validity); + } + let nulls = to_null_buffer(validity); + let buf = bytes_to_arrow_buffer(raw.clone()); + let arr: ArrayRef = Arc::new(Int64Array::new( + ScalarBuffer::new(buf, 0, values.len()), + nulls, + )); + (arr, DataType::Int64) + } + VectorData::Real(raw) => { + let values = data.as_f32_slice(); + let mut validity = validity_f32(values); + if opts.treat_infinity_as_null { + validity = merge_infinity_f32(values, validity); + } + let nulls = to_null_buffer(validity); + let buf = bytes_to_arrow_buffer(raw.clone()); + let arr: ArrayRef = Arc::new(Float32Array::new( + ScalarBuffer::new(buf, 0, values.len()), + nulls, + )); + (arr, DataType::Float32) + } + VectorData::Float(raw) => { + let values = data.as_f64_slice(); + let mut validity = validity_f64(values); + if opts.treat_infinity_as_null { + validity = merge_infinity_f64(values, validity); + } + let nulls = to_null_buffer(validity); + let buf = bytes_to_arrow_buffer(raw.clone()); + let arr: ArrayRef = Arc::new(Float64Array::new( + ScalarBuffer::new(buf, 0, values.len()), + nulls, + )); + (arr, DataType::Float64) + } + VectorData::Char(values) => { + let (arr, dt) = project_char_vector(values, opts.string); + return Ok(ArrayExport { + array: arr, + field: q_field(dt, false, "vector", Some(primitive), attr_opt, None), + }); + } + VectorData::Symbol(values) => { + let (arr, dt) = project_symbol_bytes(values, opts.symbol, opts.assume_symbol_utf8)?; + return Ok(ArrayExport { + array: arr, + field: q_field(dt, false, "vector", Some(primitive), attr_opt, None), + }); + } + VectorData::Timestamp(_) => { + let src = 
data.as_i64_slice(); + let mut validity = validity_i64(src, Q_NULL_TIMESTAMP); + if opts.treat_infinity_as_null { + validity = merge_infinity_i64(src, Q_INF_TIMESTAMP, Q_NINF_TIMESTAMP, validity); + } + let nulls = to_null_buffer(validity); + let mut bytes_buf = vec![0u8; src.len() * size_of::()]; + copy_and_offset_timestamps(src, bytemuck::cast_slice_mut(&mut bytes_buf)); + let buf = Buffer::from_vec(bytes_buf); + let arr: ArrayRef = Arc::new(TimestampNanosecondArray::new( + ScalarBuffer::new(buf, 0, src.len()), + nulls, + )); + (arr, DataType::Timestamp(TimeUnit::Nanosecond, None)) + } + VectorData::Month(_) => { + let values = data.as_i32_slice(); + // Month conversion is non-linear (variable days per month), keep scalar path. + let arr: ArrayRef = Arc::new(Date32Array::from_iter( + values.iter().map(|&v| q_month_to_date32(v)), + )); + (arr, DataType::Date32) + } + VectorData::Date(_) => { + let src = data.as_i32_slice(); + let mut validity = validity_i32(src, Q_NULL_DATE); + if opts.treat_infinity_as_null { + validity = merge_infinity_i32(src, Q_INF_DATE, Q_NINF_DATE, validity); + } + let nulls = to_null_buffer(validity); + let mut bytes_buf = vec![0u8; src.len() * size_of::()]; + copy_and_offset_dates(src, bytemuck::cast_slice_mut(&mut bytes_buf)); + let buf = Buffer::from_vec(bytes_buf); + let arr: ArrayRef = + Arc::new(Date32Array::new(ScalarBuffer::new(buf, 0, src.len()), nulls)); + (arr, DataType::Date32) + } + VectorData::Datetime(_) => { + let values = data.as_f64_slice(); + // Datetime is f64 fractional days — non-trivial transform, keep scalar path. 
+ let nulls = to_null_buffer(validity_f64(values)); + let ms_values: Vec = values + .iter() + .map(|&v| if v.is_nan() { 0 } else { datetime_to_ms(v) }) + .collect(); + let arr: ArrayRef = Arc::new(TimestampMillisecondArray::new( + ScalarBuffer::from(ms_values), + nulls, + )); + (arr, DataType::Timestamp(TimeUnit::Millisecond, None)) + } + VectorData::Timespan(raw) => { + let values = data.as_i64_slice(); + let mut validity = validity_i64(values, Q_NULL_TIMESPAN); + if opts.treat_infinity_as_null { + validity = merge_infinity_i64(values, Q_INF_TIMESPAN, Q_NINF_TIMESPAN, validity); + } + let nulls = to_null_buffer(validity); + let buf = bytes_to_arrow_buffer(raw.clone()); + let arr: ArrayRef = Arc::new(DurationNanosecondArray::new( + ScalarBuffer::new(buf, 0, values.len()), + nulls, + )); + (arr, DataType::Duration(TimeUnit::Nanosecond)) + } + VectorData::Minute(_) => { + let src = data.as_i32_slice(); + let mut validity = validity_i32(src, Q_NULL_MINUTE); + if opts.treat_infinity_as_null { + validity = merge_infinity_i32(src, Q_INF_MINUTE, Q_NINF_MINUTE, validity); + } + let nulls = to_null_buffer(validity); + let mut bytes_buf = vec![0u8; src.len() * size_of::()]; + copy_and_minutes_to_seconds(src, bytemuck::cast_slice_mut(&mut bytes_buf)); + let buf = Buffer::from_vec(bytes_buf); + let arr: ArrayRef = Arc::new(Time32SecondArray::new( + ScalarBuffer::new(buf, 0, src.len()), + nulls, + )); + (arr, DataType::Time32(TimeUnit::Second)) + } + VectorData::Second(raw) => { + let values = data.as_i32_slice(); + let mut validity = validity_i32(values, Q_NULL_SECOND); + if opts.treat_infinity_as_null { + validity = merge_infinity_i32(values, Q_INF_SECOND, Q_NINF_SECOND, validity); + } + let nulls = to_null_buffer(validity); + let buf = bytes_to_arrow_buffer(raw.clone()); + let arr: ArrayRef = Arc::new(Time32SecondArray::new( + ScalarBuffer::new(buf, 0, values.len()), + nulls, + )); + (arr, DataType::Time32(TimeUnit::Second)) + } + VectorData::Time(raw) => { + let values = 
data.as_i32_slice(); + let mut validity = validity_i32(values, Q_NULL_TIME); + if opts.treat_infinity_as_null { + validity = merge_infinity_i32(values, Q_INF_TIME, Q_NINF_TIME, validity); + } + let nulls = to_null_buffer(validity); + let buf = bytes_to_arrow_buffer(raw.clone()); + let arr: ArrayRef = Arc::new(Time32MillisecondArray::new( + ScalarBuffer::new(buf, 0, values.len()), + nulls, + )); + (arr, DataType::Time32(TimeUnit::Millisecond)) + } + }; + + let nullable = array.null_count() > 0; + let field = q_field(dt, nullable, "vector", Some(primitive), attr_opt, None); + Ok(ArrayExport { array, field }) +} + +// --------------------------------------------------------------------------- +// Symbol and char projection helpers +// --------------------------------------------------------------------------- + +/// Project a slice of q symbol byte-strings using the requested `SymbolProjection`. +fn project_symbol_bytes( + values: &[bytes::Bytes], + mode: SymbolProjection, + assume_utf8: bool, +) -> ProjectionResult<(ArrayRef, DataType)> { + match mode { + SymbolProjection::Utf8 => { + let arr: ArrayRef = if assume_utf8 { + Arc::new(StringArray::from_iter_values(values.iter().map(|v| { + // SAFETY: caller asserts symbols are valid UTF-8 (q symbols are ASCII). 
+ unsafe { std::str::from_utf8_unchecked(v) } + }))) + } else { + Arc::new(StringArray::from_iter_values( + values + .iter() + .map(|v| String::from_utf8_lossy(v).into_owned()), + )) + }; + Ok((arr, DataType::Utf8)) + } + SymbolProjection::LargeUtf8 => { + let arr: ArrayRef = if assume_utf8 { + Arc::new(LargeStringArray::from_iter_values( + values + .iter() + .map(|v| unsafe { std::str::from_utf8_unchecked(v) }), + )) + } else { + Arc::new(LargeStringArray::from_iter_values( + values + .iter() + .map(|v| String::from_utf8_lossy(v).into_owned()), + )) + }; + Ok((arr, DataType::LargeUtf8)) + } + SymbolProjection::Utf8View => { + let arr: ArrayRef = if assume_utf8 { + Arc::new(StringViewArray::from_iter_values( + values + .iter() + .map(|v| unsafe { std::str::from_utf8_unchecked(v) }), + )) + } else { + Arc::new(StringViewArray::from_iter_values( + values + .iter() + .map(|v| String::from_utf8_lossy(v).into_owned()), + )) + }; + Ok((arr, DataType::Utf8View)) + } + SymbolProjection::Dictionary => { + let mut builder = StringDictionaryBuilder::::new(); + if assume_utf8 { + for v in values { + builder.append_value(unsafe { std::str::from_utf8_unchecked(v.as_ref()) }); + } + } else { + for v in values { + builder.append_value(String::from_utf8_lossy(v.as_ref())); + } + } + let arr: ArrayRef = Arc::new(builder.finish()); + Ok(( + arr, + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + )) + } + SymbolProjection::RawBytes => { + let arr: ArrayRef = Arc::new(BinaryArray::from_iter_values( + values.iter().map(|v| v.as_ref()), + )); + Ok((arr, DataType::Binary)) + } + } +} + +/// Project a single q char atom byte. 
+fn project_char_atom(v: u8, mode: StringProjection) -> (ArrayRef, DataType) { + match mode { + StringProjection::Utf8 => { + let s = std::str::from_utf8(std::slice::from_ref(&v)).unwrap_or("\u{FFFD}"); + let arr: ArrayRef = Arc::new(StringArray::from(vec![s])); + (arr, DataType::Utf8) + } + StringProjection::Binary => { + let arr: ArrayRef = Arc::new(BinaryArray::from_iter_values(std::iter::once( + std::slice::from_ref(&v), + ))); + (arr, DataType::Binary) + } + } +} + +/// Project a q char vector as a per-character string/binary array. +fn project_char_vector(values: &[u8], mode: StringProjection) -> (ArrayRef, DataType) { + match mode { + StringProjection::Utf8 => { + let arr: ArrayRef = Arc::new(StringArray::from_iter_values( + values.iter().enumerate().map(|(i, &b)| { + if b.is_ascii() { + // SAFETY: a single ASCII byte is valid UTF-8. + unsafe { std::str::from_utf8_unchecked(&values[i..i + 1]) } + } else { + "\u{FFFD}" + } + }), + )); + (arr, DataType::Utf8) + } + StringProjection::Binary => { + let arr: ArrayRef = Arc::new(BinaryArray::from_iter_values( + values.iter().map(std::slice::from_ref), + )); + (arr, DataType::Binary) + } + } +} + +// --------------------------------------------------------------------------- +// List projection +// --------------------------------------------------------------------------- + +fn project_list(list: &List, opts: &ProjectionOptions) -> ProjectionResult { + let values = list.values(); + if values.is_empty() { + // Empty list: LargeList or List depending on opts. 
+ let child_field: FieldRef = Arc::new(Field::new("item", DataType::Null, true)); + return match opts.list { + ListProjection::List | ListProjection::ListView => { + let offsets = OffsetBuffer::::new(ScalarBuffer::from(vec![0i32])); + let child_array: ArrayRef = Arc::new(arrow_array::NullArray::new(0)); + let array: ArrayRef = Arc::new(ListArray::new( + child_field.clone(), + offsets, + child_array, + None, + )); + let field = q_field(DataType::List(child_field), false, "list", None, None, None); + Ok(ArrayExport { array, field }) + } + ListProjection::LargeList => { + let offsets = OffsetBuffer::::new(ScalarBuffer::from(vec![0i64])); + let child_array: ArrayRef = Arc::new(arrow_array::NullArray::new(0)); + let array: ArrayRef = Arc::new(LargeListArray::new( + child_field.clone(), + offsets, + child_array, + None, + )); + let field = q_field( + DataType::LargeList(child_field), + false, + "list", + None, + None, + None, + ); + Ok(ArrayExport { array, field }) + } + }; + } + + // Project each element; they may be heterogeneous. + let mut child_exports: Vec = Vec::with_capacity(values.len()); + for v in values { + child_exports.push(project(v, opts)?); + } + + // Check homogeneity: all child arrays must share the same Arrow DataType. + let first_dt = child_exports[0].array.data_type().clone(); + let homogeneous = child_exports + .iter() + .all(|e| e.array.data_type() == &first_dt); + + if !homogeneous { + return project_heterogeneous_list(values, child_exports, opts); + } + + // Concatenate child arrays. 
+ let refs: Vec<&dyn arrow_array::Array> = + child_exports.iter().map(|e| e.array.as_ref()).collect(); + let concatenated = + arrow_select::concat::concat(&refs).map_err(|e| ProjectionError::Arrow(e.to_string()))?; + + let child_field: FieldRef = Arc::new(Field::new("item", first_dt.clone(), true)); + let attr_opt = if list.attribute() == Attribute::None { + None + } else { + Some(list.attribute()) + }; + + match opts.list { + ListProjection::List | ListProjection::ListView => { + // Build i32 offsets. + let mut offsets: Vec = Vec::with_capacity(values.len() + 1); + offsets.push(0); + for e in &child_exports { + offsets.push(*offsets.last().unwrap() + e.array.len() as i32); + } + let offset_buf = OffsetBuffer::::new(ScalarBuffer::from(offsets)); + let array: ArrayRef = Arc::new(ListArray::new( + child_field.clone(), + offset_buf, + concatenated, + None, + )); + let field = q_field( + DataType::List(child_field), + false, + "list", + None, + attr_opt, + None, + ); + Ok(ArrayExport { array, field }) + } + ListProjection::LargeList => { + let mut offsets: Vec = Vec::with_capacity(values.len() + 1); + offsets.push(0); + for e in &child_exports { + offsets.push(*offsets.last().unwrap() + e.array.len() as i64); + } + let offset_buf = OffsetBuffer::::new(ScalarBuffer::from(offsets)); + let array: ArrayRef = Arc::new(LargeListArray::new( + child_field.clone(), + offset_buf, + concatenated, + None, + )); + let field = q_field( + DataType::LargeList(child_field), + false, + "list", + None, + attr_opt, + None, + ); + Ok(ArrayExport { array, field }) + } + } +} + +// --------------------------------------------------------------------------- +// Heterogeneous list → Union projection +// --------------------------------------------------------------------------- + +fn project_heterogeneous_list( + values: &[Value], + child_exports: Vec, + opts: &ProjectionOptions, +) -> ProjectionResult { + // Assign a type_id (i8) to each unique Arrow DataType in insertion order. 
+ let mut type_id_map: Vec<(DataType, i8)> = Vec::new(); + let mut type_ids: Vec = Vec::with_capacity(values.len()); + + for export in &child_exports { + let dt = export.array.data_type().clone(); + let type_id = if let Some(pos) = type_id_map.iter().position(|(d, _)| d == &dt) { + type_id_map[pos].1 + } else { + let id = type_id_map.len() as i8; + type_id_map.push((dt, id)); + id + }; + type_ids.push(type_id); + } + + let n_types = type_id_map.len(); + + // Build UnionFields from the first occurrence of each type's field descriptor. + let mut union_fields_vec: Vec<(i8, FieldRef)> = Vec::with_capacity(n_types); + for (type_id, _) in &type_id_map { + // Find the first child export with this DataType. + let first = child_exports + .iter() + .find(|e| e.array.data_type() == type_id) + .expect("type_id_map entry must have matching export"); + union_fields_vec.push(( + type_id_map + .iter() + .find(|(d, _)| d == type_id) + .map(|(_, id)| *id) + .unwrap(), + first.field.clone(), + )); + } + let union_fields = UnionFields::try_new( + union_fields_vec.iter().map(|(id, _)| *id), + union_fields_vec.iter().map(|(_, f)| f.as_ref().clone()), + ) + .map_err(|e| ProjectionError::Arrow(e.to_string()))?; + + let arrow_union_mode = match opts.union_mode { + UnionMode::Dense => ArrowUnionMode::Dense, + UnionMode::Sparse => ArrowUnionMode::Sparse, + }; + + let union_array = match opts.union_mode { + UnionMode::Dense => { + // Dense union: per-type child arrays with per-type offsets. + let mut per_type_arrays: Vec> = vec![Vec::new(); n_types]; + let mut offsets: Vec = Vec::with_capacity(values.len()); + for (export, &type_id) in child_exports.iter().zip(type_ids.iter()) { + let type_idx = type_id as usize; + offsets.push(per_type_arrays[type_idx].len() as i32); + per_type_arrays[type_idx].push(export.array.clone()); + } + // Concatenate per-type. 
+ let mut children: Vec = Vec::with_capacity(n_types); + for (type_idx, arrays) in per_type_arrays.iter().enumerate() { + let refs: Vec<&dyn arrow_array::Array> = + arrays.iter().map(|a| a.as_ref()).collect(); + let concatenated = if refs.is_empty() { + arrow_array::new_empty_array(&type_id_map[type_idx].0) + } else { + arrow_select::concat::concat(&refs) + .map_err(|e| ProjectionError::Arrow(e.to_string()))? + }; + children.push(concatenated); + } + UnionArray::try_new( + union_fields, + ScalarBuffer::from(type_ids), + Some(ScalarBuffer::from(offsets)), + children, + ) + .map_err(|e| ProjectionError::Arrow(e.to_string()))? + } + UnionMode::Sparse => { + // Sparse union: all children have length = n_elements. + // For each type's child, fill in the real values at the positions + // where that type was selected, and nulls elsewhere. + let n = values.len(); + let mut per_type_builders: Vec>> = vec![vec![None; n]; n_types]; + for (i, (export, &type_id)) in child_exports.iter().zip(type_ids.iter()).enumerate() { + per_type_builders[type_id as usize][i] = Some(export.array.clone()); + } + let mut children: Vec = Vec::with_capacity(n_types); + for (type_idx, slots) in per_type_builders.iter().enumerate() { + let dt = &type_id_map[type_idx].0; + let mut pieces: Vec = Vec::with_capacity(n); + for slot in slots { + if let Some(arr) = slot { + pieces.push(arr.clone()); + } else { + // Null slot: single-element null array of the right type. + let null = arrow_array::new_null_array(dt, 1); + pieces.push(null); + } + } + let refs: Vec<&dyn arrow_array::Array> = + pieces.iter().map(|a| a.as_ref()).collect(); + let concatenated = arrow_select::concat::concat(&refs) + .map_err(|e| ProjectionError::Arrow(e.to_string()))?; + children.push(concatenated); + } + UnionArray::try_new(union_fields, ScalarBuffer::from(type_ids), None, children) + .map_err(|e| ProjectionError::Arrow(e.to_string()))? 
+ } + }; + + let union_dt = DataType::Union(union_array.fields().clone(), arrow_union_mode); + let field = q_field(union_dt, false, "list", None, None, None); + let array: ArrayRef = Arc::new(union_array); + Ok(ArrayExport { array, field }) +} + +// --------------------------------------------------------------------------- +// Dictionary projection +// --------------------------------------------------------------------------- + +fn project_dictionary( + dict: &Dictionary, + opts: &ProjectionOptions, +) -> ProjectionResult { + let keys_export = project(dict.keys(), opts)?; + let values_export = project(dict.values(), opts)?; + + let keys_field = Arc::new(keys_export.field.as_ref().clone().with_name("keys")); + let values_field = Arc::new(values_export.field.as_ref().clone().with_name("values")); + + let fields: Fields = Fields::from(vec![keys_field, values_field]); + let array: ArrayRef = Arc::new( + StructArray::try_new( + fields.clone(), + vec![keys_export.array, values_export.array], + None, + ) + .map_err(|e| ProjectionError::Arrow(e.to_string()))?, + ); + + let sorted = dict.sorted(); + let mut meta = std::collections::HashMap::new(); + meta.insert(crate::metadata::SHAPE_KEY.to_string(), "dict".to_string()); + if sorted { + meta.insert("sorted".to_string(), "true".to_string()); + } + let struct_field: FieldRef = + Arc::new(Field::new_struct("dict", fields, false).with_metadata(meta)); + Ok(ArrayExport { + array, + field: struct_field, + }) +} + +// --------------------------------------------------------------------------- +// Temporal conversion helpers +// --------------------------------------------------------------------------- + +/// Convert a q month offset (months since 2000.01) to an Arrow `Date32` value +/// (days since 1970.01.01). Returns `None` for null months. 
+fn q_month_to_date32(q_month: i32) -> Option { + if q_month == Q_NULL_MONTH { + return None; + } + let date = if q_month >= 0 { + MILLENNIUM.checked_add_months(Months::new(q_month as u32))? + } else { + MILLENNIUM.checked_sub_months(Months::new((-q_month) as u32))? + }; + // Days since 1970-01-01 + Some((date - UNIX_EPOCH).num_days() as i32) +} + +/// Convert a q datetime (fractional days since 2000.01.01) to Arrow +/// `TimestampMillisecond` (milliseconds since 1970.01.01). +#[inline] +fn datetime_to_ms(q_datetime: f64) -> i64 { + ((q_datetime + DATE_OFFSET_DAYS as f64) * MILLIS_PER_DAY) as i64 +} + +fn attribute_meta_str(a: Attribute) -> &'static str { + match a { + Attribute::None => "none", + Attribute::Sorted => "sorted", + Attribute::Unique => "unique", + Attribute::Parted => "parted", + Attribute::Grouped => "grouped", + } +} diff --git a/crates/qroissant-core/Cargo.toml b/crates/qroissant-core/Cargo.toml new file mode 100644 index 0000000..07957e0 --- /dev/null +++ b/crates/qroissant-core/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "qroissant-core" +version.workspace = true +edition.workspace = true +license.workspace = true +publish = false + +[lib] +name = "qroissant_core" +path = "src/lib.rs" + +[dependencies] +bytemuck = { version = "1", features = ["derive"] } +bytes = "1.11.1" +memchr = "2" +rayon = "1.10" +tokio = { workspace = true, features = ["io-util"] } +futures = { workspace = true } + diff --git a/crates/qroissant-core/src/decode.rs b/crates/qroissant-core/src/decode.rs new file mode 100644 index 0000000..a3aad59 --- /dev/null +++ b/crates/qroissant-core/src/decode.rs @@ -0,0 +1,907 @@ +use rayon::prelude::*; + +use crate::error::CoreError; +use crate::error::CoreResult; +use crate::extent::value_byte_extent; +use crate::frame::Compression; +use crate::frame::Encoding; +use crate::frame::Frame; +use crate::frame::MessageHeader; +use crate::frame::decompress_ipc_body; +use crate::protocol::Attribute; +use crate::protocol::Primitive; +use 
crate::protocol::Shape;
+use crate::protocol::TypeCode;
+use crate::protocol::ValueType;
+use crate::value::Atom;
+use crate::value::Dictionary;
+use crate::value::List;
+use crate::value::Table;
+use crate::value::Value;
+use crate::value::Vector;
+use crate::value::VectorData;
+
+/// Fully decoded q IPC message: the frame header plus the decoded body value.
+#[derive(Clone, Debug, PartialEq)]
+pub struct DecodedMessage {
+    header: MessageHeader,
+    value: Value,
+}
+
+impl DecodedMessage {
+    /// Pairs an already-parsed header with its decoded value.
+    pub fn new(header: MessageHeader, value: Value) -> Self {
+        Self { header, value }
+    }
+
+    /// Returns the message header by value.
+    pub fn header(&self) -> MessageHeader {
+        self.header
+    }
+
+    /// Borrows the decoded value.
+    pub fn value(&self) -> &Value {
+        &self.value
+    }
+
+    /// Returns the q type of the decoded value.
+    pub fn qtype(&self) -> ValueType {
+        self.value.qtype()
+    }
+
+    /// Consumes the message and returns `(header, value)`.
+    pub fn into_parts(self) -> (MessageHeader, Value) {
+        (self.header, self.value)
+    }
+}
+
+/// Options controlling how q IPC messages are decoded.
+#[derive(Clone, Debug)]
+pub struct DecodeOptions {
+    /// When `true` and the top-level value is a table with at least
+    /// `parallel_column_threshold` columns, columns are decoded in parallel
+    /// using rayon's thread pool.
+    pub parallel: bool,
+    /// Minimum number of columns required to trigger parallel decode.
+    pub parallel_column_threshold: usize,
+}
+
+impl Default for DecodeOptions {
+    /// Parallel decode enabled, with a threshold of four columns.
+    fn default() -> Self {
+        Self {
+            parallel: true,
+            parallel_column_threshold: 4,
+        }
+    }
+}
+
+/// Cursor over a q IPC body: the shared buffer plus the index of the next
+/// unread byte. Every read is bounds-checked and advances `offset` only on
+/// success.
+struct BodyReader {
+    bytes: bytes::Bytes,
+    offset: usize,
+}
+
+impl BodyReader {
+    /// Starts reading at the beginning of `bytes`.
+    fn new(bytes: bytes::Bytes) -> Self {
+        Self { bytes, offset: 0 }
+    }
+
+    /// Number of bytes not yet consumed.
+    fn remaining(&self) -> usize {
+        self.bytes.len().saturating_sub(self.offset)
+    }
+
+    /// Reads exactly `N` bytes into a fixed-size array.
+    ///
+    /// Fails with `LengthOverflow` if the end offset overflows, or with an
+    /// `UnexpectedEof` I/O error if fewer than `N` bytes remain.
+    // NOTE(review): `N` is used but not declared in this signature as it
+    // appears here; the `<const N: usize>` list looks to have been lost.
+    fn read_exact(&mut self) -> CoreResult<[u8; N]> {
+        let end = self
+            .offset
+            .checked_add(N)
+            .ok_or(CoreError::LengthOverflow(usize::MAX))?;
+        let slice = self
+            .bytes
+            .get(self.offset..end)
+            .ok_or_else(|| std::io::Error::from(std::io::ErrorKind::UnexpectedEof))?;
+        self.offset = end;
+        // Cannot fail: `slice` is exactly `end - offset == N` bytes long.
+        Ok(slice.try_into().expect("fixed-size slice length checked"))
+    }
+
+    /// Returns a borrowed slice of `len` bytes and advances the offset.
+    fn read_slice(&mut self, len: usize) -> CoreResult<&[u8]> {
+        let end = self
+            .offset
+            .checked_add(len)
+            .ok_or(CoreError::LengthOverflow(usize::MAX))?;
+        let slice = self
+            .bytes
+            .get(self.offset..end)
+            .ok_or_else(|| std::io::Error::from(std::io::ErrorKind::UnexpectedEof))?;
+        self.offset = end;
+        Ok(slice)
+    }
+
+    /// Returns a zero-copy Bytes wrapper of `len` bytes and advances the offset.
+    fn read_bytes(&mut self, len: usize) -> CoreResult {
+        let end = self
+            .offset
+            .checked_add(len)
+            .ok_or(CoreError::LengthOverflow(usize::MAX))?;
+        // Checked up front because `Bytes::slice` is called with this range.
+        if end > self.bytes.len() {
+            return Err(std::io::Error::from(std::io::ErrorKind::UnexpectedEof).into());
+        }
+        let slice = self.bytes.slice(self.offset..end);
+        self.offset = end;
+        Ok(slice)
+    }
+
+    /// Returns a `Bytes` wrapper of `count * size_of::<T>()` bytes, aligned for `T`.
+    ///
+    /// If the current offset is already aligned for `T`, this is zero-copy
+    /// (a `Bytes::slice`). Otherwise the bytes are copied into a fresh buffer.
+    ///
+    /// NOTE(review): the fallback relies on the fresh buffer being aligned
+    /// for `T`; nothing in this function checks the copy's alignment.
+    fn read_bytes_aligned(&mut self, count: usize) -> CoreResult {
+        let byte_len = count
+            .checked_mul(std::mem::size_of::())
+            .ok_or(CoreError::LengthOverflow(count))?;
+        let end = self
+            .offset
+            .checked_add(byte_len)
+            .ok_or(CoreError::LengthOverflow(usize::MAX))?;
+        if end > self.bytes.len() {
+            return Err(std::io::Error::from(std::io::ErrorKind::UnexpectedEof).into());
+        }
+        // Alignment is judged on the actual address of the next byte, not on
+        // the offset.
+        let ptr = self.bytes[self.offset..].as_ptr();
+        let align = std::mem::align_of::();
+        let result = if (ptr as usize) % align == 0 {
+            // Already aligned — zero-copy slice.
+            self.bytes.slice(self.offset..end)
+        } else {
+            // Misaligned — must copy into an aligned allocation.
+            bytes::Bytes::copy_from_slice(&self.bytes[self.offset..end])
+        };
+        self.offset = end;
+        Ok(result)
+    }
+
+    /// Reads one unsigned byte.
+    fn read_u8(&mut self) -> CoreResult {
+        Ok(self.read_exact::<1>()?[0])
+    }
+
+    /// Reads one byte and reinterprets it as signed.
+    fn read_i8(&mut self) -> CoreResult {
+        Ok(self.read_u8()? as i8)
+    }
+
+    /// Reads a little-endian `i16`.
+    fn read_i16(&mut self) -> CoreResult {
+        Ok(i16::from_le_bytes(self.read_exact::<2>()?))
+    }
+
+    /// Reads a little-endian `i32`.
+    fn read_i32(&mut self) -> CoreResult {
+        Ok(i32::from_le_bytes(self.read_exact::<4>()?))
+    }
+
+    /// Reads a little-endian `i64`.
+    fn read_i64(&mut self) -> CoreResult {
+        Ok(i64::from_le_bytes(self.read_exact::<8>()?))
+    }
+
+    /// Reads a little-endian `f32`.
+    fn read_f32(&mut self) -> CoreResult {
+        Ok(f32::from_le_bytes(self.read_exact::<4>()?))
+    }
+
+    /// Reads a little-endian `f64`.
+    fn read_f64(&mut self) -> CoreResult {
+        Ok(f64::from_le_bytes(self.read_exact::<8>()?))
+    }
+
+    /// Reads the 16 raw bytes of a GUID.
+    fn read_guid(&mut self) -> CoreResult<[u8; 16]> {
+        self.read_exact::<16>()
+    }
+
+    /// Reads an `i32` collection length; negative values are rejected with
+    /// `InvalidCollectionLength`.
+    fn read_length(&mut self) -> CoreResult {
+        let length = self.read_i32()?;
+        usize::try_from(length).map_err(|_| CoreError::InvalidCollectionLength(length))
+    }
+
+    /// Reads a NUL-terminated symbol as a zero-copy slice (terminator
+    /// excluded) and skips past the terminator.
+    ///
+    /// Fails with `UnexpectedEof` if no NUL byte remains in the body.
+    fn read_symbol(&mut self) -> CoreResult {
+        let remaining = &self.bytes[self.offset..];
+        match memchr::memchr(0, remaining) {
+            Some(pos) => {
+                let symbol = self.bytes.slice(self.offset..self.offset + pos);
+                // +1 consumes the NUL terminator as well.
+                self.offset += pos + 1;
+                Ok(symbol)
+            }
+            None => Err(std::io::Error::from(std::io::ErrorKind::UnexpectedEof).into()),
+        }
+    }
+ + /// Reads `count` elements of a fixed-width type as a bulk memcpy. + /// + /// The wire bytes are reinterpreted directly into the target `Vec` via + /// `bytemuck::cast_slice_mut`, avoiding per-element parsing. This is valid + /// because we only support little-endian payloads and all target platforms + /// are little-endian. + fn read_vec( + &mut self, + count: usize, + ) -> CoreResult> { + let byte_len = count + .checked_mul(std::mem::size_of::()) + .ok_or(CoreError::LengthOverflow(count))?; + let bytes = self.read_slice(byte_len)?; + let mut values = vec![T::zeroed(); count]; + let dst: &mut [u8] = bytemuck::cast_slice_mut(&mut values); + dst.copy_from_slice(bytes); + Ok(values) + } +} + +fn decode_atom(reader: &mut BodyReader, primitive: Primitive) -> CoreResult { + Ok(match primitive { + Primitive::Boolean => Atom::Boolean(reader.read_u8()? != 0), + Primitive::Guid => Atom::Guid(reader.read_guid()?), + Primitive::Byte => Atom::Byte(reader.read_u8()?), + Primitive::Short => Atom::Short(reader.read_i16()?), + Primitive::Int => Atom::Int(reader.read_i32()?), + Primitive::Long => Atom::Long(reader.read_i64()?), + Primitive::Real => Atom::Real(reader.read_f32()?), + Primitive::Float => Atom::Float(reader.read_f64()?), + Primitive::Char => Atom::Char(reader.read_u8()?), + Primitive::Symbol => Atom::Symbol(reader.read_symbol()?), + Primitive::Timestamp => Atom::Timestamp(reader.read_i64()?), + Primitive::Month => Atom::Month(reader.read_i32()?), + Primitive::Date => Atom::Date(reader.read_i32()?), + Primitive::Datetime => Atom::Datetime(reader.read_f64()?), + Primitive::Timespan => Atom::Timespan(reader.read_i64()?), + Primitive::Minute => Atom::Minute(reader.read_i32()?), + Primitive::Second => Atom::Second(reader.read_i32()?), + Primitive::Time => Atom::Time(reader.read_i32()?), + Primitive::Mixed => unreachable!("mixed values are not encoded as atoms"), + }) +} + +fn decode_vector( + reader: &mut BodyReader, + primitive: Primitive, + attribute: Attribute, + 
length: usize, +) -> CoreResult { + let data = match primitive { + Primitive::Boolean => VectorData::Boolean(reader.read_bytes(length)?), + Primitive::Guid => VectorData::Guid( + reader.read_bytes( + length + .checked_mul(16) + .ok_or(CoreError::LengthOverflow(length))?, + )?, + ), + Primitive::Byte => VectorData::Byte(reader.read_bytes(length)?), + Primitive::Short => VectorData::Short(reader.read_bytes_aligned::(length)?), + Primitive::Int => VectorData::Int(reader.read_bytes_aligned::(length)?), + Primitive::Long => VectorData::Long(reader.read_bytes_aligned::(length)?), + Primitive::Real => VectorData::Real(reader.read_bytes_aligned::(length)?), + Primitive::Float => VectorData::Float(reader.read_bytes_aligned::(length)?), + Primitive::Char => VectorData::Char(reader.read_bytes(length)?), + Primitive::Symbol => { + let mut values = Vec::with_capacity(length); + for _ in 0..length { + values.push(reader.read_symbol()?); + } + VectorData::Symbol(values) + } + Primitive::Timestamp => VectorData::Timestamp(reader.read_bytes_aligned::(length)?), + Primitive::Month => VectorData::Month(reader.read_bytes_aligned::(length)?), + Primitive::Date => VectorData::Date(reader.read_bytes_aligned::(length)?), + Primitive::Datetime => VectorData::Datetime(reader.read_bytes_aligned::(length)?), + Primitive::Timespan => VectorData::Timespan(reader.read_bytes_aligned::(length)?), + Primitive::Minute => VectorData::Minute(reader.read_bytes_aligned::(length)?), + Primitive::Second => VectorData::Second(reader.read_bytes_aligned::(length)?), + Primitive::Time => VectorData::Time(reader.read_bytes_aligned::(length)?), + Primitive::Mixed => unreachable!("mixed values are not encoded as vectors"), + }; + + Ok(Vector::new(attribute, data)) +} + +pub(crate) fn extract_symbol_names(value: &Value) -> CoreResult> { + match value { + Value::Vector(vector) => match vector.data() { + VectorData::Symbol(values) => Ok(values.clone()), + _ => Err(CoreError::InvalidStructure( + "q table column 
names must be a symbol vector".to_string(), + )), + }, + _ => Err(CoreError::InvalidStructure( + "q table column names must be encoded as a symbol vector".to_string(), + )), + } +} + +pub(crate) fn extract_columns(value: &Value) -> CoreResult> { + match value { + Value::List(list) => Ok(list.values().to_vec()), + _ => Err(CoreError::InvalidStructure( + "q table columns must be encoded as a general list".to_string(), + )), + } +} + +fn decode_inner(reader: &mut BodyReader) -> CoreResult { + let type_code = TypeCode::try_from(reader.read_i8()?)?; + match type_code.shape() { + Shape::Atom => Ok(Value::Atom(decode_atom( + reader, + type_code + .primitive() + .expect("atom types always have a primitive"), + )?)), + Shape::Vector => { + let attribute = Attribute::try_from(reader.read_i8()?)?; + let length = reader.read_length()?; + Ok(Value::Vector(decode_vector( + reader, + type_code + .primitive() + .expect("vector types always have a primitive"), + attribute, + length, + )?)) + } + Shape::List => { + let attribute = Attribute::try_from(reader.read_i8()?)?; + let length = reader.read_length()?; + let mut values = Vec::with_capacity(length); + for _ in 0..length { + values.push(decode_inner(reader)?); + } + Ok(Value::List(List::new(attribute, values))) + } + Shape::Dictionary => { + let sorted = matches!(type_code, TypeCode::SortedDictionary); + let keys = decode_inner(reader)?; + let values = decode_inner(reader)?; + let dictionary = Dictionary::new(sorted, keys, values); + dictionary.validate()?; + Ok(Value::Dictionary(dictionary)) + } + Shape::Table => { + let attribute = Attribute::try_from(reader.read_i8()?)?; + let encoded_dictionary = decode_inner(reader)?; + let Value::Dictionary(dictionary) = encoded_dictionary else { + return Err(CoreError::InvalidStructure( + "q table payload must contain a dictionary body".to_string(), + )); + }; + let column_names = extract_symbol_names(dictionary.keys())?; + let columns = extract_columns(dictionary.values())?; + let table 
= Table::new(attribute, column_names, columns); + table.validate()?; + Ok(Value::Table(table)) + } + Shape::UnaryPrimitive => Ok(Value::UnaryPrimitive { + opcode: reader.read_i8()?, + }), + Shape::Error => { + let error_msg = reader.read_symbol()?; + Err(CoreError::QRuntime( + String::from_utf8_lossy(&error_msg).into(), + )) + } + } +} + +/// Parsed table preamble: everything before the column data. +struct TablePreamble { + attribute: Attribute, + column_names: Vec, + /// Byte offset within the body where column values start (past the + /// general-list header). + columns_start: usize, + num_columns: usize, +} + +/// Parses the table header, dictionary keys (column names), and list header. +/// +/// Shared by both the sequential and parallel table decode paths. +fn parse_table_preamble(body: &bytes::Bytes) -> CoreResult { + let mut reader = BodyReader::new(body.clone()); + + // Table: type(1) + attribute(1) + let _type_code = reader.read_i8()?; // 98 = Table + let attribute = Attribute::try_from(reader.read_i8()?)?; + + // Dictionary: type(1) + keys + values + let dict_type = TypeCode::try_from(reader.read_i8()?)?; + if !matches!(dict_type, TypeCode::Dictionary | TypeCode::SortedDictionary) { + return Err(CoreError::InvalidStructure( + "q table payload must contain a dictionary body".to_string(), + )); + } + + // Keys = symbol vector (column names) + let keys = decode_inner(&mut reader)?; + let column_names = extract_symbol_names(&keys)?; + + // Values = general list: type(1) + attr(1) + length(4) + column values + let list_type = reader.read_i8()?; + if list_type != 0 { + return Err(CoreError::InvalidStructure( + "q table columns must be encoded as a general list".to_string(), + )); + } + let _list_attr = reader.read_i8()?; + let num_columns = reader.read_length()?; + + if num_columns != column_names.len() { + return Err(CoreError::InvalidStructure(format!( + "table has {} column names but {} column values", + column_names.len(), + num_columns + ))); + } + + 
Ok(TablePreamble { + attribute, + column_names, + columns_start: reader.offset, + num_columns, + }) +} + +/// Attempts parallel table decode. Returns `None` if the column count is +/// below the threshold, allowing the caller to fall back to sequential. +fn try_decode_table_parallel(body: bytes::Bytes, threshold: usize) -> CoreResult> { + let preamble = parse_table_preamble(&body)?; + + if preamble.num_columns < threshold { + return Ok(None); + } + + // Use value_byte_extent to find each column's byte range without parsing + let mut column_ranges: Vec<(usize, usize)> = Vec::with_capacity(preamble.num_columns); + let mut scan = preamble.columns_start; + for _ in 0..preamble.num_columns { + let extent = value_byte_extent(&body, scan)?; + column_ranges.push((scan, scan + extent)); + scan += extent; + } + + // Parallel decode: each column gets its own byte slice + let columns: Vec> = column_ranges + .par_iter() + .map(|&(start, end)| { + let mut col_reader = BodyReader::new(body.slice(start..end)); + let value = decode_inner(&mut col_reader)?; + if col_reader.remaining() != 0 { + return Err(CoreError::TrailingBodyBytes(col_reader.remaining())); + } + Ok(value) + }) + .collect(); + + let columns: Vec = columns.into_iter().collect::>>()?; + + let table = Table::new(preamble.attribute, preamble.column_names, columns); + table.validate()?; + Ok(Some(Value::Table(table))) +} + +/// Decodes one q value body from a little-endian byte slice. +/// +/// Returns `UnsupportedEndianness` for big-endian payloads. +pub fn decode_value(body: bytes::Bytes, encoding: Encoding) -> CoreResult { + decode_value_with_options(body, encoding, &DecodeOptions::default()) +} + +/// Decodes one q value body with configurable options. +/// +/// When `options.parallel` is `true` and the body contains a table with +/// enough columns, columns are decoded in parallel using rayon. 
+pub fn decode_value_with_options(
+    body: bytes::Bytes,
+    encoding: Encoding,
+    options: &DecodeOptions,
+) -> CoreResult {
+    if encoding != Encoding::LittleEndian {
+        return Err(CoreError::UnsupportedEndianness(encoding));
+    }
+
+    // Fast path: parallel table decode
+    // (98 is the table type code; the body's first byte is the type code of
+    // the top-level value).
+    if options.parallel && body.first() == Some(&98) {
+        // `None` means the table was below the column threshold; fall
+        // through to the sequential decoder in that case.
+        if let Some(table) =
+            try_decode_table_parallel(body.clone(), options.parallel_column_threshold)?
+        {
+            return Ok(table);
+        }
+    }
+
+    let mut reader = BodyReader::new(body);
+    let value = decode_inner(&mut reader)?;
+    // A body must hold exactly one value; leftover bytes are an error.
+    if reader.remaining() != 0 {
+        return Err(CoreError::TrailingBodyBytes(reader.remaining()));
+    }
+    Ok(value)
+}
+
+/// Decodes a full q IPC frame into its header and value.
+///
+/// Returns `UnsupportedEndianness` for big-endian payloads.
+pub fn decode_message(frame_bytes: bytes::Bytes) -> CoreResult {
+    decode_message_with_options(frame_bytes, &DecodeOptions::default())
+}
+
+/// Decodes a full q IPC frame with configurable options.
+///
+/// The header is parsed first. Compressed bodies are decompressed into a new
+/// buffer; uncompressed bodies are decoded from a zero-copy slice of
+/// `frame_bytes` that starts right after the header.
+pub fn decode_message_with_options(
+    frame_bytes: bytes::Bytes,
+    options: &DecodeOptions,
+) -> CoreResult {
+    let frame = Frame::parse(&frame_bytes)?;
+    let header = frame.header();
+
+    if header.encoding() != Encoding::LittleEndian {
+        return Err(CoreError::UnsupportedEndianness(header.encoding()));
+    }
+
+    if header.compression() != Compression::Uncompressed {
+        let decompressed = decompress_ipc_body(frame.body(), header.encoding())?;
+        let value = decode_value_with_options(
+            bytes::Bytes::from(decompressed),
+            header.encoding(),
+            options,
+        )?;
+        return Ok(DecodedMessage::new(header, value));
+    }
+
+    // Slice the original `Bytes` rather than copying `frame.body()`, so
+    // decoded vectors can keep sharing the frame's buffer.
+    let value = decode_value_with_options(
+        frame_bytes.slice(crate::frame::HEADER_LEN..),
+        header.encoding(),
+        options,
+    )?;
+    Ok(DecodedMessage::new(header, value))
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::protocol::Attribute;
+
+    #[test]
+    fn decode_int_atom_body() {
+        let value = decode_value(
+            bytes::Bytes::from(vec![i8::from(TypeCode::IntAtom) as u8, 42, 0, 0, 0]),
Encoding::LittleEndian, + ) + .unwrap(); + + assert_eq!(value, Value::Atom(Atom::Int(42))); + assert_eq!(value.qtype(), ValueType::atom(Primitive::Int)); + } + + #[test] + fn decode_int_vector_body() { + let value = decode_value( + bytes::Bytes::from_static(&[6_u8, 1, 3, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0]), + Encoding::LittleEndian, + ) + .unwrap(); + + assert_eq!( + value, + Value::Vector(Vector::new( + Attribute::Sorted, + VectorData::from_i32s(&[1, 2, 3]), + )) + ); + } + + #[test] + fn decode_symbol_atom_body() { + let value = decode_value( + bytes::Bytes::from_static(&[245_u8, b'a', b'b', 0]), + Encoding::LittleEndian, + ) + .unwrap(); + + assert_eq!( + value, + Value::Atom(Atom::Symbol(bytes::Bytes::from_static(b"ab"))) + ); + } + + #[test] + fn decode_list_body() { + let value = decode_value( + bytes::Bytes::from_static(&[0_u8, 0, 2, 0, 0, 0, 250, 42, 0, 0, 0, 245, b'a', b'b', 0]), + Encoding::LittleEndian, + ) + .unwrap(); + + assert_eq!( + value, + Value::List(List::new( + Attribute::None, + vec![ + Value::Atom(Atom::Int(42)), + Value::Atom(Atom::Symbol(bytes::Bytes::from_static(b"ab"))) + ], + )) + ); + } + + #[test] + fn decode_dictionary_body() { + let value = decode_value( + bytes::Bytes::from_static(&[ + 99_u8, 11, 0, 2, 0, 0, 0, b'a', 0, b'b', 0, 6, 0, 2, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, + 0, + ]), + Encoding::LittleEndian, + ) + .unwrap(); + + assert_eq!( + value, + Value::Dictionary(Dictionary::new( + false, + Value::Vector(Vector::new( + Attribute::None, + VectorData::Symbol(vec![ + bytes::Bytes::from_static(b"a"), + bytes::Bytes::from_static(b"b") + ]), + )), + Value::Vector(Vector::new(Attribute::None, VectorData::from_i32s(&[1, 2]),)), + )) + ); + } + + #[test] + fn decode_table_body() { + let value = decode_value( + bytes::Bytes::from_static(&[ + 98_u8, 0, 99, 11, 0, 2, 0, 0, 0, b's', b'y', b'm', 0, b'p', b'x', 0, 0, 0, 2, 0, 0, + 0, 11, 0, 2, 0, 0, 0, b'a', 0, b'b', 0, 6, 0, 2, 0, 0, 0, 10, 0, 0, 0, 20, 0, 0, 0, + ]), + 
Encoding::LittleEndian, + ) + .unwrap(); + + assert_eq!( + value, + Value::Table(Table::new( + Attribute::None, + vec![ + bytes::Bytes::from_static(b"sym"), + bytes::Bytes::from_static(b"px") + ], + vec![ + Value::Vector(Vector::new( + Attribute::None, + VectorData::Symbol(vec![ + bytes::Bytes::from_static(b"a"), + bytes::Bytes::from_static(b"b") + ]), + )), + Value::Vector(Vector::new( + Attribute::None, + VectorData::from_i32s(&[10, 20]), + )), + ], + )) + ); + } + + #[test] + fn decode_unary_primitive_body() { + let value = decode_value( + bytes::Bytes::from_static(&[101_u8, 0]), + Encoding::LittleEndian, + ) + .unwrap(); + + assert_eq!(value, Value::UnaryPrimitive { opcode: 0 }); + } + + #[test] + fn decode_rejects_trailing_bytes() { + assert!(matches!( + decode_value( + bytes::Bytes::from_static(&[250_u8, 42, 0, 0, 0, 99]), + Encoding::LittleEndian + ), + Err(CoreError::TrailingBodyBytes(1)) + )); + } + + #[test] + fn decode_rejects_malformed_table_structure() { + let err = decode_value( + bytes::Bytes::from_static(&[ + 98_u8, 0, 99, 11, 0, 1, 0, 0, 0, b'x', 0, 250, 42, 0, 0, 0, + ]), + Encoding::LittleEndian, + ) + .unwrap_err(); + + assert!(matches!(err, CoreError::InvalidStructure(_))); + } + + #[test] + fn decode_rejects_big_endian() { + assert!(matches!( + decode_value( + bytes::Bytes::from_static(&[250_u8, 0, 0, 0, 42]), + Encoding::BigEndian + ), + Err(CoreError::UnsupportedEndianness(Encoding::BigEndian)) + )); + } + + // -- Parallel decode tests -- + + use crate::encode::encode_value; + + /// Helper: encode a table, decode with parallel=true and parallel=false, + /// and verify both produce identical results. 
+ fn assert_parallel_matches_sequential(table: &Value) { + let body = encode_value(table, Encoding::LittleEndian).unwrap(); + + let seq_opts = DecodeOptions { + parallel: false, + ..Default::default() + }; + let par_opts = DecodeOptions { + parallel: true, + parallel_column_threshold: 1, // force parallel even for small tables + }; + + let seq = decode_value_with_options( + bytes::Bytes::from(body.clone()), + Encoding::LittleEndian, + &seq_opts, + ) + .unwrap(); + let par = decode_value_with_options( + bytes::Bytes::from(body.clone()), + Encoding::LittleEndian, + &par_opts, + ) + .unwrap(); + + assert_eq!(seq, par, "parallel decode must match sequential decode"); + assert_eq!(&seq, table, "decoded value must match original"); + } + + #[test] + fn parallel_decode_multi_column_table() { + let table = Value::Table(Table::new( + Attribute::None, + vec![ + bytes::Bytes::from_static(b"a"), + bytes::Bytes::from_static(b"b"), + bytes::Bytes::from_static(b"c"), + bytes::Bytes::from_static(b"d"), + ], + vec![ + Value::Vector(Vector::new( + Attribute::None, + VectorData::from_i32s(&[1, 2, 3]), + )), + Value::Vector(Vector::new( + Attribute::None, + VectorData::Symbol(vec![ + bytes::Bytes::from_static(b"x"), + bytes::Bytes::from_static(b"y"), + bytes::Bytes::from_static(b"z"), + ]), + )), + Value::Vector(Vector::new( + Attribute::None, + VectorData::from_f64s(&[1.0, 2.0, 3.0]), + )), + Value::Vector(Vector::new( + Attribute::None, + VectorData::from_i64s(&[100, 200, 300]), + )), + ], + )); + assert_parallel_matches_sequential(&table); + } + + #[test] + fn parallel_decode_mixed_type_columns() { + let table = Value::Table(Table::new( + Attribute::None, + vec![ + bytes::Bytes::from_static(b"bools"), + bytes::Bytes::from_static(b"guids"), + bytes::Bytes::from_static(b"chars"), + bytes::Bytes::from_static(b"times"), + bytes::Bytes::from_static(b"dates"), + ], + vec![ + Value::Vector(Vector::new( + Attribute::None, + VectorData::Boolean(bytes::Bytes::from_static(&[1, 0])), + )), + 
Value::Vector(Vector::new( + Attribute::None, + VectorData::from_guids(&[[0u8; 16], [1u8; 16]]), + )), + Value::Vector(Vector::new( + Attribute::None, + VectorData::Char(bytes::Bytes::from_static(b"ab")), + )), + Value::Vector(Vector::new( + Attribute::None, + VectorData::from_times(&[1000, 2000]), + )), + Value::Vector(Vector::new( + Attribute::None, + VectorData::from_dates(&[100, 200]), + )), + ], + )); + assert_parallel_matches_sequential(&table); + } + + #[test] + fn parallel_decode_below_threshold_falls_back_to_sequential() { + // 2 columns, threshold 4 → should use sequential path + let table = Value::Table(Table::new( + Attribute::None, + vec![ + bytes::Bytes::from_static(b"a"), + bytes::Bytes::from_static(b"b"), + ], + vec![ + Value::Vector(Vector::new(Attribute::None, VectorData::from_i32s(&[1, 2]))), + Value::Vector(Vector::new(Attribute::None, VectorData::from_i32s(&[3, 4]))), + ], + )); + let body = encode_value(&table, Encoding::LittleEndian).unwrap(); + let opts = DecodeOptions { + parallel: true, + parallel_column_threshold: 4, + }; + let decoded = decode_value_with_options( + bytes::Bytes::from(body.clone()), + Encoding::LittleEndian, + &opts, + ) + .unwrap(); + assert_eq!(decoded, table); + } + + #[test] + fn parallel_decode_non_table_ignores_parallel_flag() { + // Non-table values should decode normally regardless of parallel flag + let value = Value::Atom(Atom::Int(42)); + let body = encode_value(&value, Encoding::LittleEndian).unwrap(); + let opts = DecodeOptions { + parallel: true, + parallel_column_threshold: 1, + }; + let decoded = decode_value_with_options( + bytes::Bytes::from(body.clone()), + Encoding::LittleEndian, + &opts, + ) + .unwrap(); + assert_eq!(decoded, value); + } + + #[test] + fn parse_table_preamble_correct() { + let table = Value::Table(Table::new( + Attribute::None, + vec![ + bytes::Bytes::from_static(b"a"), + bytes::Bytes::from_static(b"b"), + bytes::Bytes::from_static(b"c"), + bytes::Bytes::from_static(b"d"), + 
bytes::Bytes::from_static(b"e"), + ], + vec![ + Value::Vector(Vector::new(Attribute::None, VectorData::from_i32s(&[1]))), + Value::Vector(Vector::new(Attribute::None, VectorData::from_i32s(&[2]))), + Value::Vector(Vector::new(Attribute::None, VectorData::from_i32s(&[3]))), + Value::Vector(Vector::new(Attribute::None, VectorData::from_i32s(&[4]))), + Value::Vector(Vector::new(Attribute::None, VectorData::from_i32s(&[5]))), + ], + )); + let body = encode_value(&table, Encoding::LittleEndian).unwrap(); + let preamble = parse_table_preamble(&bytes::Bytes::from(body)).unwrap(); + assert_eq!(preamble.num_columns, 5); + assert_eq!(preamble.column_names.len(), 5); + assert_eq!(&preamble.column_names[0][..], b"a"); + assert_eq!(&preamble.column_names[4][..], b"e"); + } +} diff --git a/crates/qroissant-core/src/encode.rs b/crates/qroissant-core/src/encode.rs new file mode 100644 index 0000000..f3d6299 --- /dev/null +++ b/crates/qroissant-core/src/encode.rs @@ -0,0 +1,385 @@ +use crate::error::CoreError; +use crate::error::CoreResult; +use crate::frame::Compression; +use crate::frame::Encoding; +use crate::frame::MessageType; +use crate::frame::serialize_body_as_message; +use crate::protocol::TypeCode; +use crate::value::Atom; +use crate::value::List; +use crate::value::Table; +use crate::value::Value; +use crate::value::Vector; +use crate::value::VectorData; + +fn push_i16(buffer: &mut Vec, value: i16) { + buffer.extend_from_slice(&value.to_le_bytes()); +} + +fn push_i32(buffer: &mut Vec, value: i32) { + buffer.extend_from_slice(&value.to_le_bytes()); +} + +fn push_i64(buffer: &mut Vec, value: i64) { + buffer.extend_from_slice(&value.to_le_bytes()); +} + +fn push_f32(buffer: &mut Vec, value: f32) { + buffer.extend_from_slice(&value.to_le_bytes()); +} + +fn push_f64(buffer: &mut Vec, value: f64) { + buffer.extend_from_slice(&value.to_le_bytes()); +} + +fn push_length(buffer: &mut Vec, value: usize) { + let value = i32::try_from(value).expect("supported q vectors fit in 32-bit 
length"); + push_i32(buffer, value); +} + +fn encode_atom(atom: &Atom, buffer: &mut Vec) { + match atom { + Atom::Boolean(value) => { + buffer.push(TypeCode::BooleanAtom as i8 as u8); + buffer.push(u8::from(*value)); + } + Atom::Guid(value) => { + buffer.push(TypeCode::GuidAtom as i8 as u8); + buffer.extend_from_slice(value); + } + Atom::Byte(value) => { + buffer.push(TypeCode::ByteAtom as i8 as u8); + buffer.push(*value); + } + Atom::Short(value) => { + buffer.push(TypeCode::ShortAtom as i8 as u8); + push_i16(buffer, *value); + } + Atom::Int(value) => { + buffer.push(TypeCode::IntAtom as i8 as u8); + push_i32(buffer, *value); + } + Atom::Long(value) => { + buffer.push(TypeCode::LongAtom as i8 as u8); + push_i64(buffer, *value); + } + Atom::Real(value) => { + buffer.push(TypeCode::RealAtom as i8 as u8); + push_f32(buffer, *value); + } + Atom::Float(value) => { + buffer.push(TypeCode::FloatAtom as i8 as u8); + push_f64(buffer, *value); + } + Atom::Char(value) => { + buffer.push(TypeCode::CharAtom as i8 as u8); + buffer.push(*value); + } + Atom::Symbol(value) => { + buffer.push(TypeCode::SymbolAtom as i8 as u8); + buffer.extend_from_slice(value); + buffer.push(0); + } + Atom::Timestamp(value) => { + buffer.push(TypeCode::TimestampAtom as i8 as u8); + push_i64(buffer, *value); + } + Atom::Month(value) => { + buffer.push(TypeCode::MonthAtom as i8 as u8); + push_i32(buffer, *value); + } + Atom::Date(value) => { + buffer.push(TypeCode::DateAtom as i8 as u8); + push_i32(buffer, *value); + } + Atom::Datetime(value) => { + buffer.push(TypeCode::DatetimeAtom as i8 as u8); + push_f64(buffer, *value); + } + Atom::Timespan(value) => { + buffer.push(TypeCode::TimespanAtom as i8 as u8); + push_i64(buffer, *value); + } + Atom::Minute(value) => { + buffer.push(TypeCode::MinuteAtom as i8 as u8); + push_i32(buffer, *value); + } + Atom::Second(value) => { + buffer.push(TypeCode::SecondAtom as i8 as u8); + push_i32(buffer, *value); + } + Atom::Time(value) => { + 
buffer.push(TypeCode::TimeAtom as i8 as u8); + push_i32(buffer, *value); + } + } +} + +fn encode_vector(vector: &Vector, buffer: &mut Vec) { + let attribute = i8::from(vector.attribute()) as u8; + let data = vector.data(); + let len = data.len(); + + // All non-Symbol variants store raw Bytes; pick the type code, write header + raw bytes. + let (type_code, raw) = match data { + VectorData::Boolean(b) => (TypeCode::BooleanVector, Some(b)), + VectorData::Guid(b) => (TypeCode::GuidVector, Some(b)), + VectorData::Byte(b) => (TypeCode::ByteVector, Some(b)), + VectorData::Short(b) => (TypeCode::ShortVector, Some(b)), + VectorData::Int(b) => (TypeCode::IntVector, Some(b)), + VectorData::Long(b) => (TypeCode::LongVector, Some(b)), + VectorData::Real(b) => (TypeCode::RealVector, Some(b)), + VectorData::Float(b) => (TypeCode::FloatVector, Some(b)), + VectorData::Char(b) => (TypeCode::CharVector, Some(b)), + VectorData::Timestamp(b) => (TypeCode::TimestampVector, Some(b)), + VectorData::Month(b) => (TypeCode::MonthVector, Some(b)), + VectorData::Date(b) => (TypeCode::DateVector, Some(b)), + VectorData::Datetime(b) => (TypeCode::DatetimeVector, Some(b)), + VectorData::Timespan(b) => (TypeCode::TimespanVector, Some(b)), + VectorData::Minute(b) => (TypeCode::MinuteVector, Some(b)), + VectorData::Second(b) => (TypeCode::SecondVector, Some(b)), + VectorData::Time(b) => (TypeCode::TimeVector, Some(b)), + VectorData::Symbol(_) => (TypeCode::SymbolVector, None), + }; + + buffer.push(type_code as i8 as u8); + buffer.push(attribute); + push_length(buffer, len); + + if let Some(raw) = raw { + buffer.extend_from_slice(raw); + } else if let VectorData::Symbol(values) = data { + for value in values { + buffer.extend_from_slice(value); + buffer.push(0); + } + } +} + +fn encode_table(table: &Table, buffer: &mut Vec) -> CoreResult<()> { + buffer.push(TypeCode::Table as i8 as u8); + buffer.push(i8::from(table.attribute()) as u8); + + buffer.push(TypeCode::Dictionary as i8 as u8); + 
buffer.push(TypeCode::SymbolVector as i8 as u8); + buffer.push(0); + push_length(buffer, table.column_names().len()); + for name in table.column_names() { + buffer.extend_from_slice(name); + buffer.push(0); + } + + buffer.push(TypeCode::GeneralList as i8 as u8); + buffer.push(0); + push_length(buffer, table.columns().len()); + for column in table.columns() { + encode_value_into(column, buffer)?; + } + + Ok(()) +} + +fn encode_list(list: &List, buffer: &mut Vec) -> CoreResult<()> { + buffer.push(TypeCode::GeneralList as i8 as u8); + buffer.push(i8::from(list.attribute()) as u8); + push_length(buffer, list.len()); + for value in list.values() { + encode_value_into(value, buffer)?; + } + + Ok(()) +} + +fn encode_value_into(value: &Value, buffer: &mut Vec) -> CoreResult<()> { + match value { + Value::Atom(atom) => encode_atom(atom, buffer), + Value::Vector(vector) => encode_vector(vector, buffer), + Value::List(list) => encode_list(list, buffer)?, + Value::Dictionary(dictionary) => { + dictionary.validate()?; + buffer.push(if dictionary.sorted() { + TypeCode::SortedDictionary as i8 as u8 + } else { + TypeCode::Dictionary as i8 as u8 + }); + encode_value_into(dictionary.keys(), buffer)?; + encode_value_into(dictionary.values(), buffer)?; + } + Value::Table(table) => { + table.validate()?; + encode_table(table, buffer)?; + } + Value::UnaryPrimitive { opcode } => { + buffer.push(TypeCode::UnaryPrimitive as i8 as u8); + buffer.push(*opcode as u8); + } + } + + Ok(()) +} + +/// Encodes a supported q value as a little-endian q IPC body. +/// +/// Returns `UnsupportedEndianness` for big-endian encoding. +pub fn encode_value(value: &Value, encoding: Encoding) -> CoreResult> { + if encoding != Encoding::LittleEndian { + return Err(CoreError::UnsupportedEndianness(encoding)); + } + let mut buffer = Vec::new(); + encode_value_into(value, &mut buffer)?; + Ok(buffer) +} + +/// Encodes a supported q value as a full q IPC message. 
+/// +/// Returns `UnsupportedEndianness` for big-endian encoding. +pub fn encode_message( + value: &Value, + encoding: Encoding, + message_type: MessageType, + compression: Compression, +) -> CoreResult> { + let body = encode_value(value, encoding)?; + serialize_body_as_message(&body, encoding, message_type, compression) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::decode::decode_value; + use crate::protocol::Attribute; + use crate::value::Dictionary; + use crate::value::List; + use crate::value::Table; + + #[test] + fn encode_int_atom_body() { + let value = Value::Atom(Atom::Int(42)); + let body = encode_value(&value, Encoding::LittleEndian).unwrap(); + + assert_eq!(body, vec![250, 42, 0, 0, 0]); + assert_eq!( + decode_value(bytes::Bytes::from(body.clone()), Encoding::LittleEndian).unwrap(), + value + ); + } + + #[test] + fn encode_rejects_big_endian() { + let value = Value::Vector(Vector::new( + Attribute::Sorted, + VectorData::from_i32s(&[1, 2, 3]), + )); + assert!(matches!( + encode_value(&value, Encoding::BigEndian), + Err(CoreError::UnsupportedEndianness(Encoding::BigEndian)) + )); + } + + #[test] + fn encode_symbol_vector_body() { + let value = Value::Vector(Vector::new( + Attribute::None, + VectorData::Symbol(vec![ + bytes::Bytes::from_static(b"alpha"), + bytes::Bytes::from_static(b"beta"), + ]), + )); + let body = encode_value(&value, Encoding::LittleEndian).unwrap(); + + assert_eq!( + body, + bytes::Bytes::from_static(b"\x0b\x00\x02\0\0\0alpha\0beta\0") + ); + assert_eq!( + decode_value(bytes::Bytes::from(body.clone()), Encoding::LittleEndian).unwrap(), + value + ); + } + + #[test] + fn encode_list_body() { + let value = Value::List(List::new( + Attribute::None, + vec![ + Value::Atom(Atom::Int(42)), + Value::Atom(Atom::Symbol(bytes::Bytes::from_static(b"ab"))), + ], + )); + let body = encode_value(&value, Encoding::LittleEndian).unwrap(); + + assert_eq!( + decode_value(bytes::Bytes::from(body.clone()), Encoding::LittleEndian).unwrap(), + 
value + ); + } + + #[test] + fn encode_dictionary_body() { + let value = Value::Dictionary(Dictionary::new( + false, + Value::Vector(Vector::new( + Attribute::None, + VectorData::Symbol(vec![ + bytes::Bytes::from_static(b"a"), + bytes::Bytes::from_static(b"b"), + ]), + )), + Value::Vector(Vector::new(Attribute::None, VectorData::from_i32s(&[1, 2]))), + )); + let body = encode_value(&value, Encoding::LittleEndian).unwrap(); + + assert_eq!( + decode_value(bytes::Bytes::from(body.clone()), Encoding::LittleEndian).unwrap(), + value + ); + } + + #[test] + fn encode_table_body() { + let value = Value::Table(Table::new( + Attribute::None, + vec![ + bytes::Bytes::from_static(b"sym"), + bytes::Bytes::from_static(b"px"), + ], + vec![ + Value::Vector(Vector::new( + Attribute::None, + VectorData::Symbol(vec![ + bytes::Bytes::from_static(b"a"), + bytes::Bytes::from_static(b"b"), + ]), + )), + Value::Vector(Vector::new( + Attribute::None, + VectorData::from_i32s(&[10, 20]), + )), + ], + )); + let body = encode_value(&value, Encoding::LittleEndian).unwrap(); + + assert_eq!( + decode_value(bytes::Bytes::from(body.clone()), Encoding::LittleEndian).unwrap(), + value + ); + } + + #[test] + fn encode_rejects_malformed_table_structure() { + let value = Value::Table(Table::new( + crate::protocol::Attribute::None, + vec![ + bytes::Bytes::from_static(b"sym"), + bytes::Bytes::from_static(b"px"), + ], + vec![Value::Vector(Vector::new( + crate::protocol::Attribute::None, + VectorData::Symbol(vec![bytes::Bytes::from_static(b"a")]), + ))], + )); + + let err = encode_value(&value, Encoding::LittleEndian).unwrap_err(); + assert!(matches!(err, crate::error::CoreError::InvalidStructure(_))); + } +} diff --git a/crates/qroissant-core/src/error.rs b/crates/qroissant-core/src/error.rs new file mode 100644 index 0000000..0f6c933 --- /dev/null +++ b/crates/qroissant-core/src/error.rs @@ -0,0 +1,112 @@ +use std::error::Error; +use std::fmt; + +use crate::frame::Compression; +use crate::frame::Encoding; 
+ +/// Core result type used across the qroissant core crate. +pub type CoreResult = Result; + +/// Errors produced by low-level q IPC frame handling. +#[derive(Debug)] +pub enum CoreError { + InvalidEncoding(u8), + InvalidMessageType(u8), + InvalidCompression(u8), + InvalidAttribute(i8), + InvalidTypeCode(i8), + InvalidMessageLength(usize), + InvalidCollectionLength(i32), + InvalidStructure(String), + TruncatedHeader { actual: usize }, + FrameLengthMismatch { declared: usize, actual: usize }, + TrailingBodyBytes(usize), + UnsupportedEndianness(Encoding), + UnsupportedCompression(Compression), + UnsupportedTypeCode(i8), + LengthOverflow(usize), + Io(std::io::Error), + QRuntime(String), +} + +impl fmt::Display for CoreError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::InvalidEncoding(value) => write!( + f, + "invalid q IPC encoding value {value}; expected 0 (big-endian) or 1 (little-endian)" + ), + Self::InvalidMessageType(value) => write!( + f, + "invalid q IPC message type value {value}; expected 0 (asynchronous), 1 (synchronous), or 2 (response)" + ), + Self::InvalidCompression(value) => write!( + f, + "invalid q IPC compression value {value}; expected 0 (uncompressed), 1 (compressed), or 2 (compressed large)" + ), + Self::InvalidAttribute(value) => write!( + f, + "invalid q attribute value {value}; expected 0 (none), 1 (sorted), 2 (unique), 3 (parted), or 4 (grouped)" + ), + Self::InvalidTypeCode(value) => write!(f, "invalid q IPC type code {value}"), + Self::InvalidMessageLength(length) => { + write!( + f, + "invalid q IPC message length {length}; minimum is 8 bytes" + ) + } + Self::InvalidCollectionLength(length) => { + write!( + f, + "invalid q collection length {length}; length must be non-negative" + ) + } + Self::InvalidStructure(message) => write!(f, "{message}"), + Self::TruncatedHeader { actual } => write!( + f, + "truncated q IPC header: expected 8 bytes, received {actual}" + ), + Self::FrameLengthMismatch { 
declared, actual } => write!( + f, + "q IPC header declares {declared} bytes, but frame contains {actual}" + ), + Self::TrailingBodyBytes(remaining) => write!( + f, + "q IPC body contains {remaining} trailing bytes after the decoded value" + ), + Self::UnsupportedEndianness(encoding) => write!( + f, + "serialization currently supports only little-endian q IPC frames, got {encoding:?}" + ), + Self::UnsupportedCompression(compression) => write!( + f, + "serialization currently supports only uncompressed q IPC frames, got {compression:?}" + ), + Self::UnsupportedTypeCode(value) => write!( + f, + "q IPC type code {value} is valid but not implemented yet in the current decoder" + ), + Self::LengthOverflow(length) => write!( + f, + "q IPC frame length {length} exceeds 32-bit header capacity" + ), + Self::Io(error) => error.fmt(f), + Self::QRuntime(message) => write!(f, "q runtime error: {message}"), + } + } +} + +impl Error for CoreError { + fn source(&self) -> Option<&(dyn Error + 'static)> { + match self { + Self::Io(error) => Some(error), + _ => None, + } + } +} + +impl From for CoreError { + fn from(value: std::io::Error) -> Self { + Self::Io(value) + } +} diff --git a/crates/qroissant-core/src/extent.rs b/crates/qroissant-core/src/extent.rs new file mode 100644 index 0000000..5ce992a --- /dev/null +++ b/crates/qroissant-core/src/extent.rs @@ -0,0 +1,518 @@ +//! Zero-allocation byte extent calculator for serialized q IPC values. +//! +//! Given a byte slice and an offset pointing to the start of a serialized q +//! value, [`value_byte_extent`] returns how many bytes that value occupies +//! without allocating memory or constructing a [`Value`]. This is used by +//! the parallel column decoder to split a table's column data into +//! independent sub-slices before dispatching them to worker threads. 
+ +use crate::error::CoreError; +use crate::error::CoreResult; +use crate::protocol::Primitive; +use crate::protocol::Shape; +use crate::protocol::TypeCode; + +/// Returns the byte extent of a serialized q value starting at `bytes[offset..]`. +/// +/// The function reads only type codes, attributes, and lengths — it never +/// allocates or constructs a `Value`. For fixed-width vectors this is O(1); +/// for symbol vectors and nested structures it scans forward. +pub fn value_byte_extent(bytes: &[u8], offset: usize) -> CoreResult { + if offset >= bytes.len() { + return Err(CoreError::InvalidStructure(format!( + "extent: offset {offset} beyond buffer length {}", + bytes.len() + ))); + } + + let type_code = TypeCode::try_from(bytes[offset] as i8)?; + let shape = type_code.shape(); + + match shape { + Shape::Atom => atom_extent(bytes, offset, type_code), + Shape::Vector => vector_extent(bytes, offset, type_code), + Shape::List => list_extent(bytes, offset), + Shape::Dictionary => dictionary_extent(bytes, offset), + Shape::Table => table_extent(bytes, offset), + Shape::UnaryPrimitive => { + // type byte + opcode byte + check_available(bytes, offset, 2)?; + Ok(2) + } + Shape::Error => { + check_available(bytes, offset, 1)?; + let data_start = offset + 1; + let pos = bytes[data_start..] + .iter() + .position(|&b| b == 0) + .ok_or_else(|| { + CoreError::InvalidStructure(format!( + "extent: unterminated error string at offset {offset}" + )) + })?; + Ok(1 + pos + 1) + } + } +} + +/// Checks that at least `need` bytes are available from `offset`. +#[inline] +fn check_available(bytes: &[u8], offset: usize, need: usize) -> CoreResult<()> { + if offset + need > bytes.len() { + Err(CoreError::InvalidStructure(format!( + "extent: need {need} bytes at offset {offset}, but buffer length is {}", + bytes.len() + ))) + } else { + Ok(()) + } +} + +/// Reads an i32 length field at `bytes[offset..offset+4]` (little-endian). 
+#[inline] +fn read_len(bytes: &[u8], offset: usize) -> CoreResult { + check_available(bytes, offset, 4)?; + let len = i32::from_le_bytes(bytes[offset..offset + 4].try_into().unwrap()); + if len < 0 { + return Err(CoreError::InvalidStructure(format!( + "extent: negative length {len} at offset {offset}" + ))); + } + Ok(len as usize) +} + +fn atom_extent(bytes: &[u8], offset: usize, type_code: TypeCode) -> CoreResult { + // 1 byte for type code + data bytes + let primitive = type_code + .primitive() + .ok_or(CoreError::InvalidTypeCode(type_code as i8))?; + + if let Some(width) = primitive.width() { + check_available(bytes, offset, 1 + width)?; + Ok(1 + width) + } else { + // Symbol atom: scan for null terminator + debug_assert_eq!(primitive, Primitive::Symbol); + let data_start = offset + 1; + let pos = bytes[data_start..] + .iter() + .position(|&b| b == 0) + .ok_or_else(|| { + CoreError::InvalidStructure(format!( + "extent: unterminated symbol atom at offset {offset}" + )) + })?; + // type byte + symbol bytes + null terminator + Ok(1 + pos + 1) + } +} + +fn vector_extent(bytes: &[u8], offset: usize, type_code: TypeCode) -> CoreResult { + // Header: 1 (type) + 1 (attribute) + 4 (length) = 6 bytes + const HEADER: usize = 6; + check_available(bytes, offset, HEADER)?; + let length = read_len(bytes, offset + 2)?; + + let primitive = type_code + .primitive() + .ok_or(CoreError::InvalidTypeCode(type_code as i8))?; + + if let Some(width) = primitive.width() { + let data_bytes = length + .checked_mul(width) + .ok_or(CoreError::LengthOverflow(length))?; + check_available(bytes, offset, HEADER + data_bytes)?; + Ok(HEADER + data_bytes) + } else { + // Symbol vector: scan through `length` null-terminated strings + debug_assert_eq!(primitive, Primitive::Symbol); + let mut scan = offset + HEADER; + for _ in 0..length { + let pos = bytes[scan..].iter().position(|&b| b == 0).ok_or_else(|| { + CoreError::InvalidStructure(format!( + "extent: unterminated symbol in vector at offset 
{scan}" + )) + })?; + scan += pos + 1; // skip past the null terminator + } + Ok(scan - offset) + } +} + +fn list_extent(bytes: &[u8], offset: usize) -> CoreResult { + // Header: 1 (type) + 1 (attribute) + 4 (length) = 6 bytes + const HEADER: usize = 6; + check_available(bytes, offset, HEADER)?; + let length = read_len(bytes, offset + 2)?; + + let mut scan = offset + HEADER; + for _ in 0..length { + let child_extent = value_byte_extent(bytes, scan)?; + scan += child_extent; + } + Ok(scan - offset) +} + +fn dictionary_extent(bytes: &[u8], offset: usize) -> CoreResult { + // 1 byte for type code (99 or 127), then keys value, then values value + check_available(bytes, offset, 1)?; + let keys_extent = value_byte_extent(bytes, offset + 1)?; + let values_extent = value_byte_extent(bytes, offset + 1 + keys_extent)?; + Ok(1 + keys_extent + values_extent) +} + +fn table_extent(bytes: &[u8], offset: usize) -> CoreResult { + // 1 byte type code + 1 byte attribute + inner dictionary + check_available(bytes, offset, 2)?; + let dict_extent = value_byte_extent(bytes, offset + 2)?; + Ok(2 + dict_extent) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::decode::decode_value; + use crate::encode::encode_value; + use crate::frame::Encoding; + use crate::value::*; + + /// Helper: encode a value, then verify extent equals encoded body length. 
+ fn assert_extent_matches(value: &Value) { + let body = encode_value(value, Encoding::LittleEndian).unwrap(); + let extent = value_byte_extent(&body, 0).unwrap(); + assert_eq!( + extent, + body.len(), + "extent mismatch for {value:?}: expected {}, got {extent}", + body.len() + ); + } + + // -- Atoms -- + + #[test] + fn extent_boolean_atom() { + assert_extent_matches(&Value::Atom(Atom::Boolean(true))); + } + + #[test] + fn extent_byte_atom() { + assert_extent_matches(&Value::Atom(Atom::Byte(0x42))); + } + + #[test] + fn extent_short_atom() { + assert_extent_matches(&Value::Atom(Atom::Short(42))); + } + + #[test] + fn extent_int_atom() { + assert_extent_matches(&Value::Atom(Atom::Int(42))); + } + + #[test] + fn extent_long_atom() { + assert_extent_matches(&Value::Atom(Atom::Long(42))); + } + + #[test] + fn extent_real_atom() { + assert_extent_matches(&Value::Atom(Atom::Real(1.5))); + } + + #[test] + fn extent_float_atom() { + assert_extent_matches(&Value::Atom(Atom::Float(1.5))); + } + + #[test] + fn extent_char_atom() { + assert_extent_matches(&Value::Atom(Atom::Char(b'c'))); + } + + #[test] + fn extent_symbol_atom() { + assert_extent_matches(&Value::Atom(Atom::Symbol(bytes::Bytes::from_static( + b"hello", + )))); + } + + #[test] + fn extent_empty_symbol_atom() { + assert_extent_matches(&Value::Atom(Atom::Symbol(bytes::Bytes::from_static(b"")))); + } + + #[test] + fn extent_guid_atom() { + assert_extent_matches(&Value::Atom(Atom::Guid([0u8; 16]))); + } + + #[test] + fn extent_timestamp_atom() { + assert_extent_matches(&Value::Atom(Atom::Timestamp(1))); + } + + #[test] + fn extent_month_atom() { + assert_extent_matches(&Value::Atom(Atom::Month(1))); + } + + #[test] + fn extent_date_atom() { + assert_extent_matches(&Value::Atom(Atom::Date(1))); + } + + #[test] + fn extent_datetime_atom() { + assert_extent_matches(&Value::Atom(Atom::Datetime(1.5))); + } + + #[test] + fn extent_timespan_atom() { + assert_extent_matches(&Value::Atom(Atom::Timespan(1))); + } + + #[test] 
+ fn extent_minute_atom() { + assert_extent_matches(&Value::Atom(Atom::Minute(1))); + } + + #[test] + fn extent_second_atom() { + assert_extent_matches(&Value::Atom(Atom::Second(1))); + } + + #[test] + fn extent_time_atom() { + assert_extent_matches(&Value::Atom(Atom::Time(1))); + } + + // -- Vectors -- + + #[test] + fn extent_int_vector() { + assert_extent_matches(&Value::Vector(Vector::new( + crate::protocol::Attribute::None, + VectorData::from_i32s(&[1, 2, 3]), + ))); + } + + #[test] + fn extent_empty_int_vector() { + assert_extent_matches(&Value::Vector(Vector::new( + crate::protocol::Attribute::None, + VectorData::from_i32s(&[]), + ))); + } + + #[test] + fn extent_symbol_vector() { + assert_extent_matches(&Value::Vector(Vector::new( + crate::protocol::Attribute::None, + VectorData::Symbol(vec![ + bytes::Bytes::from_static(b"alpha"), + bytes::Bytes::from_static(b"beta"), + ]), + ))); + } + + #[test] + fn extent_empty_symbol_vector() { + assert_extent_matches(&Value::Vector(Vector::new( + crate::protocol::Attribute::None, + VectorData::Symbol(vec![]), + ))); + } + + #[test] + fn extent_boolean_vector() { + assert_extent_matches(&Value::Vector(Vector::new( + crate::protocol::Attribute::None, + VectorData::Boolean(bytes::Bytes::from_static(&[1, 0, 1])), + ))); + } + + #[test] + fn extent_guid_vector() { + assert_extent_matches(&Value::Vector(Vector::new( + crate::protocol::Attribute::None, + VectorData::from_guids(&[[0u8; 16], [1u8; 16]]), + ))); + } + + #[test] + fn extent_long_vector() { + assert_extent_matches(&Value::Vector(Vector::new( + crate::protocol::Attribute::None, + VectorData::from_i64s(&[1, 2, 3]), + ))); + } + + #[test] + fn extent_float_vector() { + assert_extent_matches(&Value::Vector(Vector::new( + crate::protocol::Attribute::None, + VectorData::from_f64s(&[1.0, 2.0]), + ))); + } + + #[test] + fn extent_char_vector() { + assert_extent_matches(&Value::Vector(Vector::new( + crate::protocol::Attribute::None, + 
VectorData::Char(bytes::Bytes::from_static(b"hello")), + ))); + } + + #[test] + fn extent_byte_vector() { + assert_extent_matches(&Value::Vector(Vector::new( + crate::protocol::Attribute::None, + VectorData::Byte(bytes::Bytes::from(vec![1, 2, 3])), + ))); + } + + // -- Composites -- + + #[test] + fn extent_general_list() { + assert_extent_matches(&Value::List(List::new( + crate::protocol::Attribute::None, + vec![ + Value::Atom(Atom::Int(42)), + Value::Atom(Atom::Symbol(bytes::Bytes::from_static(b"ab"))), + ], + ))); + } + + #[test] + fn extent_empty_list() { + assert_extent_matches(&Value::List(List::new( + crate::protocol::Attribute::None, + vec![], + ))); + } + + #[test] + fn extent_dictionary() { + assert_extent_matches(&Value::Dictionary(Dictionary::new( + false, + Value::Vector(Vector::new( + crate::protocol::Attribute::None, + VectorData::Symbol(vec![ + bytes::Bytes::from_static(b"a"), + bytes::Bytes::from_static(b"b"), + ]), + )), + Value::Vector(Vector::new( + crate::protocol::Attribute::None, + VectorData::from_i32s(&[1, 2]), + )), + ))); + } + + #[test] + fn extent_sorted_dictionary() { + assert_extent_matches(&Value::Dictionary(Dictionary::new( + true, + Value::Vector(Vector::new( + crate::protocol::Attribute::Sorted, + VectorData::Symbol(vec![ + bytes::Bytes::from_static(b"a"), + bytes::Bytes::from_static(b"b"), + ]), + )), + Value::Vector(Vector::new( + crate::protocol::Attribute::None, + VectorData::from_i32s(&[1, 2]), + )), + ))); + } + + #[test] + fn extent_table() { + assert_extent_matches(&Value::Table(Table::new( + crate::protocol::Attribute::None, + vec![ + bytes::Bytes::from_static(b"sym"), + bytes::Bytes::from_static(b"px"), + ], + vec![ + Value::Vector(Vector::new( + crate::protocol::Attribute::None, + VectorData::Symbol(vec![ + bytes::Bytes::from_static(b"a"), + bytes::Bytes::from_static(b"b"), + ]), + )), + Value::Vector(Vector::new( + crate::protocol::Attribute::None, + VectorData::from_i32s(&[10, 20]), + )), + ], + ))); + } + + #[test] + 
fn extent_nested_list() { + assert_extent_matches(&Value::List(List::new( + crate::protocol::Attribute::None, + vec![ + Value::Vector(Vector::new( + crate::protocol::Attribute::None, + VectorData::from_i32s(&[1, 2, 3]), + )), + Value::Vector(Vector::new( + crate::protocol::Attribute::None, + VectorData::from_i32s(&[4, 5]), + )), + ], + ))); + } + + #[test] + fn extent_unary_primitive() { + let value = Value::UnaryPrimitive { opcode: 42 }; + assert_extent_matches(&value); + } + + /// Verify extent matches for every value encoded in a real roundtrip body. + #[test] + fn extent_matches_decode_consumption() { + // Encode a table, get the body, verify extent == body.len() + let table = Value::Table(Table::new( + crate::protocol::Attribute::None, + vec![ + bytes::Bytes::from_static(b"a"), + bytes::Bytes::from_static(b"b"), + bytes::Bytes::from_static(b"c"), + ], + vec![ + Value::Vector(Vector::new( + crate::protocol::Attribute::None, + VectorData::from_i32s(&[1, 2, 3]), + )), + Value::Vector(Vector::new( + crate::protocol::Attribute::None, + VectorData::Symbol(vec![ + bytes::Bytes::from_static(b"x"), + bytes::Bytes::from_static(b"y"), + bytes::Bytes::from_static(b"z"), + ]), + )), + Value::Vector(Vector::new( + crate::protocol::Attribute::None, + VectorData::from_f64s(&[1.0, 2.0, 3.0]), + )), + ], + )); + let body = encode_value(&table, Encoding::LittleEndian).unwrap(); + let extent = value_byte_extent(&body, 0).unwrap(); + assert_eq!(extent, body.len()); + + // Also verify roundtrip + let decoded = + decode_value(bytes::Bytes::from(body.clone()), Encoding::LittleEndian).unwrap(); + assert_eq!(decoded, table); + } +} diff --git a/crates/qroissant-core/src/frame.rs b/crates/qroissant-core/src/frame.rs new file mode 100644 index 0000000..e9f31f2 --- /dev/null +++ b/crates/qroissant-core/src/frame.rs @@ -0,0 +1,826 @@ +use std::io::Read; + +use crate::error::CoreError; +use crate::error::CoreResult; + +/// Fixed byte length of every q IPC message header. 
+pub const HEADER_LEN: usize = 8;
+
+/// Endianness marker stored in the first q IPC header byte.
+#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
+pub enum Encoding {
+    BigEndian,
+    #[default]
+    LittleEndian,
+}
+
+impl Encoding {
+    /// Interprets four bytes as a `u32` in this byte order.
+    fn decode_u32(self, bytes: [u8; 4]) -> u32 {
+        match self {
+            Self::BigEndian => u32::from_be_bytes(bytes),
+            Self::LittleEndian => u32::from_le_bytes(bytes),
+        }
+    }
+
+    /// Serializes a `u32` into four bytes in this byte order.
+    fn encode_u32(self, value: u32) -> [u8; 4] {
+        match self {
+            Self::BigEndian => value.to_be_bytes(),
+            Self::LittleEndian => value.to_le_bytes(),
+        }
+    }
+}
+
+impl From<Encoding> for u8 {
+    /// Returns the wire marker: `0` for big-endian, `1` for little-endian.
+    fn from(value: Encoding) -> Self {
+        match value {
+            Encoding::BigEndian => 0,
+            Encoding::LittleEndian => 1,
+        }
+    }
+}
+
+impl TryFrom<u8> for Encoding {
+    type Error = CoreError;
+
+    /// Decodes the wire marker; any byte other than `0` or `1` yields
+    /// `CoreError::InvalidEncoding` carrying the offending byte.
+    fn try_from(value: u8) -> CoreResult<Self> {
+        match value {
+            0 => Ok(Self::BigEndian),
+            1 => Ok(Self::LittleEndian),
+            _ => Err(CoreError::InvalidEncoding(value)),
+        }
+    }
+}
+
+/// q IPC message kind stored in the second q IPC header byte.
+#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
+pub enum MessageType {
+    #[default]
+    Asynchronous,
+    Synchronous,
+    Response,
+}
+
+impl From<MessageType> for u8 {
+    /// Returns the wire marker: `0` asynchronous, `1` synchronous, `2` response.
+    fn from(value: MessageType) -> Self {
+        match value {
+            MessageType::Asynchronous => 0,
+            MessageType::Synchronous => 1,
+            MessageType::Response => 2,
+        }
+    }
+}
+
+impl TryFrom<u8> for MessageType {
+    type Error = CoreError;
+
+    /// Decodes the wire marker; any byte outside `0..=2` yields
+    /// `CoreError::InvalidMessageType` carrying the offending byte.
+    fn try_from(value: u8) -> CoreResult<Self> {
+        match value {
+            0 => Ok(Self::Asynchronous),
+            1 => Ok(Self::Synchronous),
+            2 => Ok(Self::Response),
+            _ => Err(CoreError::InvalidMessageType(value)),
+        }
+    }
+}
+
+/// q IPC compression marker stored in the third q IPC header byte.
+#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)] +pub enum Compression { + #[default] + Uncompressed, + Compressed, + CompressedLarge, +} + +impl From for u8 { + fn from(value: Compression) -> Self { + match value { + Compression::Uncompressed => 0, + Compression::Compressed => 1, + Compression::CompressedLarge => 2, + } + } +} + +impl TryFrom for Compression { + type Error = CoreError; + + fn try_from(value: u8) -> CoreResult { + match value { + 0 => Ok(Self::Uncompressed), + 1 => Ok(Self::Compressed), + 2 => Ok(Self::CompressedLarge), + _ => Err(CoreError::InvalidCompression(value)), + } + } +} + +/// Decoded q IPC message header. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct MessageHeader { + encoding: Encoding, + message_type: MessageType, + compression: Compression, + size: usize, +} + +impl MessageHeader { + /// Builds a validated message header. + pub fn new( + encoding: Encoding, + message_type: MessageType, + compression: Compression, + size: usize, + ) -> CoreResult { + if size < HEADER_LEN { + return Err(CoreError::InvalidMessageLength(size)); + } + + Ok(Self { + encoding, + message_type, + compression, + size, + }) + } + + /// Parses a message header from an exact 8-byte array. + pub fn from_bytes(bytes: [u8; HEADER_LEN]) -> CoreResult { + let encoding = Encoding::try_from(bytes[0])?; + let message_type = MessageType::try_from(bytes[1])?; + let compression = Compression::try_from(bytes[2])?; + let size = encoding.decode_u32(bytes[4..8].try_into().expect("fixed-size slice")) as usize; + Self::new(encoding, message_type, compression, size) + } + + /// Parses a message header from a byte slice. + pub fn parse(bytes: &[u8]) -> CoreResult { + let header: [u8; HEADER_LEN] = bytes + .get(..HEADER_LEN) + .ok_or(CoreError::TruncatedHeader { + actual: bytes.len(), + })? + .try_into() + .expect("header slice length already checked"); + Self::from_bytes(header) + } + + /// Serializes the header back to its q IPC byte representation. 
+ pub fn to_bytes(self) -> CoreResult<[u8; HEADER_LEN]> { + let size = u32::try_from(self.size).map_err(|_| CoreError::LengthOverflow(self.size))?; + let mut bytes = [0_u8; HEADER_LEN]; + bytes[0] = self.encoding.into(); + bytes[1] = self.message_type.into(); + bytes[2] = self.compression.into(); + bytes[4..8].copy_from_slice(&self.encoding.encode_u32(size)); + Ok(bytes) + } + + pub fn encoding(self) -> Encoding { + self.encoding + } + + pub fn message_type(self) -> MessageType { + self.message_type + } + + pub fn compression(self) -> Compression { + self.compression + } + + pub fn size(self) -> usize { + self.size + } + + pub fn body_len(self) -> usize { + self.size - HEADER_LEN + } +} + +/// Borrowed validated q IPC frame. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct Frame<'a> { + header: MessageHeader, + body: &'a [u8], +} + +impl<'a> Frame<'a> { + /// Validates a full q IPC frame and returns borrowed header/body views. + pub fn parse(bytes: &'a [u8]) -> CoreResult { + let header = MessageHeader::parse(bytes)?; + if bytes.len() != header.size() { + return Err(CoreError::FrameLengthMismatch { + declared: header.size(), + actual: bytes.len(), + }); + } + + Ok(Self { + header, + body: &bytes[HEADER_LEN..], + }) + } + + pub fn header(self) -> MessageHeader { + self.header + } + + pub fn body(self) -> &'a [u8] { + self.body + } +} + +/// Decompresses a q IPC compressed body (follows the 8-byte header). +/// +/// The first 4 bytes of the compressed body are a size prefix encoding the +/// total decompressed message length including the 8-byte header. The +/// remaining bytes are the compressed payload using q's LZW-style algorithm: +/// a flag byte drives 8 decisions — bit clear emits a literal byte, bit set +/// emits a back-reference (2 fixed bytes + n extra bytes) via a 256-entry +/// XOR-keyed lookup table. 
+pub fn decompress_ipc_body(compressed: &[u8], encoding: Encoding) -> CoreResult> { + if compressed.len() < 4 { + return Err(CoreError::InvalidStructure(format!( + "compressed body must be at least 4 bytes for size prefix, got {}", + compressed.len() + ))); + } + + let size_with_header = match encoding { + Encoding::LittleEndian => { + i32::from_le_bytes(compressed[..4].try_into().expect("validated length")) + } + Encoding::BigEndian => { + i32::from_be_bytes(compressed[..4].try_into().expect("validated length")) + } + }; + if size_with_header < 8 { + return Err(CoreError::InvalidStructure(format!( + "compressed size prefix {size_with_header} is less than minimum header size 8" + ))); + } + let size = (size_with_header - 8) as usize; + + let mut decompressed = vec![0_u8; size]; + let mut aa = [0_i32; 256]; + let mut n = 0_usize; + let mut f = 0_usize; + let mut s = 0_usize; + let mut p = 0_usize; + let mut i = 0_usize; + let mut d = 4_usize; // skip the 4-byte size prefix + + while s < size { + if i == 0 { + if d >= compressed.len() { + return Err(CoreError::InvalidStructure( + "unexpected end of compressed data while reading flag byte".to_string(), + )); + } + f = compressed[d] as usize; + d += 1; + i = 1; + } + + if (f & i) != 0 { + // Back-reference: lookup key byte + extra count byte + if d + 2 > compressed.len() { + return Err(CoreError::InvalidStructure( + "insufficient data for back-reference (need 2 bytes)".to_string(), + )); + } + let mut r = aa[compressed[d] as usize] as usize; + d += 1; + + if r >= size { + return Err(CoreError::InvalidStructure(format!( + "back-reference start {r} exceeds decompressed buffer size {size}" + ))); + } + if s >= size { + return Err(CoreError::InvalidStructure(format!( + "write index {s} exceeds decompressed buffer size {size}" + ))); + } + decompressed[s] = decompressed[r]; + s += 1; + r += 1; + + if r >= size { + return Err(CoreError::InvalidStructure(format!( + "back-reference position {r} exceeds decompressed buffer size 
{size}" + ))); + } + if s >= size { + return Err(CoreError::InvalidStructure(format!( + "write index {s} exceeds decompressed buffer size {size}" + ))); + } + decompressed[s] = decompressed[r]; + s += 1; + r += 1; + + n = compressed[d] as usize; + d += 1; + + if r + n > size { + return Err(CoreError::InvalidStructure(format!( + "back-reference range {r}..{} exceeds decompressed buffer size {size}", + r + n + ))); + } + if s + n > size { + return Err(CoreError::InvalidStructure(format!( + "write range {s}..{} exceeds decompressed buffer size {size}", + s + n + ))); + } + for m in 0..n { + decompressed[s + m] = decompressed[r + m]; + } + } else { + // Literal byte + if d >= compressed.len() { + return Err(CoreError::InvalidStructure( + "unexpected end of compressed data while reading literal byte".to_string(), + )); + } + decompressed[s] = compressed[d]; + s += 1; + d += 1; + } + + // Update the XOR lookup table with newly emitted bytes + while p < s.saturating_sub(1) { + aa[(decompressed[p] ^ decompressed[p + 1]) as usize] = p as i32; + p += 1; + } + + if (f & i) != 0 { + s += n; + p = s; + } + + i *= 2; + if i == 256 { + i = 0; + } + } + + Ok(decompressed) +} + +/// Serializes a q-encoded body as a complete q IPC message. +/// +/// This mirrors the current rewrite contract: qroissant only emits +/// little-endian, uncompressed frames for now. 
+///
+/// # Errors
+///
+/// Returns `CoreError::UnsupportedEndianness` for any encoding other than
+/// little-endian, `CoreError::UnsupportedCompression` for any compression
+/// other than uncompressed, and `CoreError::LengthOverflow` when the framed
+/// length does not fit the header's 32-bit length field.
+pub fn serialize_body_as_message(
+    body: &[u8],
+    encoding: Encoding,
+    message_type: MessageType,
+    compression: Compression,
+) -> CoreResult<Vec<u8>> {
+    if encoding != Encoding::LittleEndian {
+        return Err(CoreError::UnsupportedEndianness(encoding));
+    }
+    if compression != Compression::Uncompressed {
+        return Err(CoreError::UnsupportedCompression(compression));
+    }
+
+    let size = HEADER_LEN
+        .checked_add(body.len())
+        .ok_or(CoreError::LengthOverflow(usize::MAX))?;
+    let header = MessageHeader::new(encoding, message_type, compression, size)?;
+    // Encode the header first so that a body too large for the 32-bit length
+    // field fails before the output buffer is allocated.
+    let header_bytes = header.to_bytes()?;
+    let mut payload = Vec::with_capacity(size);
+    payload.extend_from_slice(&header_bytes);
+    payload.extend_from_slice(body);
+    Ok(payload)
+}
+
+/// Reads the total q IPC frame length from an 8-byte header.
+///
+/// The returned length includes the header itself. The header is fully
+/// validated, so invalid marker bytes or a length below `HEADER_LEN` are
+/// reported as errors.
+pub fn read_message_length(header: &[u8; HEADER_LEN]) -> CoreResult<usize> {
+    Ok(MessageHeader::from_bytes(*header)?.size())
+}
+
+/// Reads one complete q IPC frame from an IO stream.
+///
+/// Reads the 8-byte header, then exactly as many further bytes as the header
+/// declares, and returns header and body as one buffer. The body is returned
+/// as received; no decompression is performed here.
+pub fn read_frame<R: Read>(reader: &mut R) -> CoreResult<Vec<u8>> {
+    let mut header = [0_u8; HEADER_LEN];
+    reader.read_exact(&mut header)?;
+    let frame_len = read_message_length(&header)?;
+    // frame_len >= HEADER_LEN is guaranteed by header validation, so the
+    // slicing below cannot panic.
+    let mut frame = vec![0_u8; frame_len];
+    frame[..HEADER_LEN].copy_from_slice(&header);
+    reader.read_exact(&mut frame[HEADER_LEN..])?;
+    Ok(frame)
+}
+
+/// Incremental q IPC decompressor that can be fed compressed bytes as they
+/// arrive from the network, overlapping I/O with decompression work.
+///
+/// The q LZW algorithm reads compressed input forward-only — back-references
+/// target the *output* buffer, not the input. This means we can process
+/// compressed bytes as soon as they arrive without buffering the entire
+/// compressed payload first.
+/// +/// # Usage +/// +/// ```ignore +/// let mut dec = StreamingDecompressor::new(size_prefix, Encoding::LittleEndian)?; +/// while !dec.is_complete() { +/// let chunk = read_from_network()?; +/// dec.feed(&chunk)?; +/// } +/// let body = dec.finish()?; +/// ``` +pub struct StreamingDecompressor { + decompressed: Vec, + aa: [i32; 256], + compressed_buf: Vec, + d: usize, + s: usize, + p: usize, + f: usize, + i: usize, + size: usize, + read_ptr: usize, +} + +impl StreamingDecompressor { + /// Creates a new streaming decompressor from the 4-byte size prefix + /// (the first 4 bytes of the compressed body after the 8-byte header). + pub fn new(size_prefix: [u8; 4], encoding: Encoding) -> CoreResult { + let size_with_header = match encoding { + Encoding::LittleEndian => i32::from_le_bytes(size_prefix), + Encoding::BigEndian => i32::from_be_bytes(size_prefix), + }; + if size_with_header < 8 { + return Err(CoreError::InvalidStructure(format!( + "compressed size prefix {size_with_header} is less than minimum header size 8" + ))); + } + let size = (size_with_header - 8) as usize; + + Ok(Self { + decompressed: vec![0_u8; size], + aa: [0_i32; 256], + compressed_buf: Vec::new(), + d: 0, + s: 0, + p: 0, + f: 0, + i: 0, + size, + read_ptr: 0, + }) + } + + pub fn feed(&mut self, chunk: &[u8]) -> CoreResult { + self.compressed_buf.extend_from_slice(chunk); + let prev_s = self.s; + + while self.s < self.size { + if self.i == 0 { + if self.d >= self.compressed_buf.len() { + break; + } + self.f = self.compressed_buf[self.d] as usize; + self.d += 1; + self.i = 1; + } + + let is_backref = (self.f & self.i) != 0; + let mut n = 0; + + if is_backref { + if self.d + 2 > self.compressed_buf.len() { + break; + } + let mut r = self.aa[self.compressed_buf[self.d] as usize] as usize; + self.d += 1; + if r >= self.size || self.s + 2 > self.size { + return Err(CoreError::InvalidStructure( + "backref out of bounds".to_string(), + )); + } + self.decompressed[self.s] = self.decompressed[r]; + 
self.s += 1; + r += 1; + + if r >= self.size || self.s + 1 > self.size { + return Err(CoreError::InvalidStructure( + "backref out of bounds".to_string(), + )); + } + self.decompressed[self.s] = self.decompressed[r]; + self.s += 1; + r += 1; + + n = self.compressed_buf[self.d] as usize; + self.d += 1; + if r + n > self.size || self.s + n > self.size { + return Err(CoreError::InvalidStructure( + "backref out of bounds".to_string(), + )); + } + for m in 0..n { + self.decompressed[self.s + m] = self.decompressed[r + m]; + } + } else { + if self.d >= self.compressed_buf.len() { + break; + } + self.decompressed[self.s] = self.compressed_buf[self.d]; + self.s += 1; + self.d += 1; + } + + // Sync lookup table + while self.p < self.s.saturating_sub(1) { + self.aa[(self.decompressed[self.p] ^ self.decompressed[self.p + 1]) as usize] = + self.p as i32; + self.p += 1; + } + + if is_backref { + self.s += n; + self.p = self.s; + } + + self.i *= 2; + if self.i == 256 { + self.i = 0; + } + } + + // Keep memory usage in check by draining processed bytes + if self.d > 0 { + self.compressed_buf.drain(0..self.d); + self.d = 0; + } + + Ok(self.s - prev_s) + } + + /// Returns `true` when decompression is complete. + pub fn is_complete(&self) -> bool { + self.s >= self.size + } + + /// Current number of decompressed bytes available. + pub fn decompressed_len(&self) -> usize { + self.s + } + + /// Number of decompressed bytes that have not yet been read. + pub fn unread_len(&self) -> usize { + self.s - self.read_ptr + } + + /// Returns a slice of the next available decompressed bytes. + pub fn next_chunk(&self) -> &[u8] { + &self.decompressed[self.read_ptr..self.s] + } + + /// Advances the read pointer by `len` bytes. + pub fn consume(&mut self, len: usize) { + self.read_ptr = (self.read_ptr + len).min(self.s); + } + + /// Total expected decompressed size. + pub fn total_size(&self) -> usize { + self.size + } + + /// Borrows the decompressed output produced so far. 
+ pub fn decompressed(&self) -> &[u8] { + &self.decompressed[..self.s] + } + + /// Consumes the decompressor and returns the completed output buffer. + /// + /// Returns an error if decompression is not yet complete. + pub fn finish(self) -> CoreResult> { + if !self.is_complete() { + return Err(CoreError::InvalidStructure(format!( + "streaming decompress: incomplete — {}/{} bytes decompressed", + self.s, self.size + ))); + } + Ok(self.decompressed) + } +} + +#[cfg(test)] +mod tests { + use std::io::Cursor; + + use super::*; + + #[test] + fn encoding_round_trips_from_u8() { + assert_eq!(Encoding::try_from(0).unwrap(), Encoding::BigEndian); + assert_eq!(Encoding::try_from(1).unwrap(), Encoding::LittleEndian); + assert!(matches!( + Encoding::try_from(9), + Err(CoreError::InvalidEncoding(9)) + )); + } + + #[test] + fn compression_supports_compressed_large() { + assert_eq!(Compression::try_from(0).unwrap(), Compression::Uncompressed); + assert_eq!(Compression::try_from(1).unwrap(), Compression::Compressed); + assert_eq!( + Compression::try_from(2).unwrap(), + Compression::CompressedLarge + ); + } + + #[test] + fn header_parses_little_endian_payloads() { + let header = MessageHeader::from_bytes([1, 2, 2, 0, 24, 0, 0, 0]).unwrap(); + + assert_eq!(header.encoding(), Encoding::LittleEndian); + assert_eq!(header.message_type(), MessageType::Response); + assert_eq!(header.compression(), Compression::CompressedLarge); + assert_eq!(header.size(), 24); + assert_eq!(header.body_len(), 16); + } + + #[test] + fn header_parses_big_endian_lengths() { + let header = MessageHeader::from_bytes([0, 1, 0, 0, 0, 0, 0, 16]).unwrap(); + + assert_eq!(header.encoding(), Encoding::BigEndian); + assert_eq!(header.message_type(), MessageType::Synchronous); + assert_eq!(header.size(), 16); + } + + #[test] + fn header_rejects_lengths_smaller_than_header() { + assert!(matches!( + MessageHeader::from_bytes([1, 2, 0, 0, 7, 0, 0, 0]), + Err(CoreError::InvalidMessageLength(7)) + )); + } + + #[test] + fn 
header_to_bytes_round_trips() { + let header = MessageHeader::new( + Encoding::LittleEndian, + MessageType::Response, + Compression::Compressed, + 64, + ) + .unwrap(); + + let bytes = header.to_bytes().unwrap(); + assert_eq!(MessageHeader::from_bytes(bytes).unwrap(), header); + } + + #[test] + fn frame_parse_validates_declared_length() { + let frame = [1, 2, 0, 0, 10, 0, 0, 0, 42, 43]; + let parsed = Frame::parse(&frame).unwrap(); + + assert_eq!(parsed.header().size(), 10); + assert_eq!(parsed.body(), &[42, 43]); + } + + #[test] + fn frame_parse_rejects_length_mismatch() { + let frame = [1, 2, 0, 0, 11, 0, 0, 0, 42, 43]; + assert!(matches!( + Frame::parse(&frame), + Err(CoreError::FrameLengthMismatch { + declared: 11, + actual: 10 + }) + )); + } + + #[test] + fn serialize_body_wraps_uncompressed_little_endian_body() { + let payload = serialize_body_as_message( + &[10, 20, 30], + Encoding::LittleEndian, + MessageType::Synchronous, + Compression::Uncompressed, + ) + .unwrap(); + + assert_eq!(payload, vec![1, 1, 0, 0, 11, 0, 0, 0, 10, 20, 30]); + } + + #[test] + fn serialize_body_rejects_big_endian_for_now() { + assert!(matches!( + serialize_body_as_message( + &[1], + Encoding::BigEndian, + MessageType::Asynchronous, + Compression::Uncompressed, + ), + Err(CoreError::UnsupportedEndianness(Encoding::BigEndian)) + )); + } + + #[test] + fn serialize_body_rejects_compressed_frames_for_now() { + assert!(matches!( + serialize_body_as_message( + &[1], + Encoding::LittleEndian, + MessageType::Asynchronous, + Compression::CompressedLarge, + ), + Err(CoreError::UnsupportedCompression( + Compression::CompressedLarge + )) + )); + } + + #[test] + fn read_frame_reads_complete_payload() { + let mut cursor = Cursor::new(vec![1, 2, 0, 0, 10, 0, 0, 0, 42, 43]); + let frame = read_frame(&mut cursor).unwrap(); + + assert_eq!(frame, vec![1, 2, 0, 0, 10, 0, 0, 0, 42, 43]); + } + + // ----------------------------------------------------------------------- + // StreamingDecompressor tests + 
// ----------------------------------------------------------------------- + + /// Helper: compress a body using the batch decompressor, then verify the + /// streaming decompressor produces identical output. + /// + /// Since we don't have an encoder for compression, we test by creating + /// compressed data that the batch decompressor can handle and verifying + /// the streaming variant matches. We use decompress_ipc_body as the + /// reference implementation. + fn assert_streaming_matches_batch(compressed_body: &[u8]) { + let batch_result = decompress_ipc_body(compressed_body, Encoding::LittleEndian).unwrap(); + + // Feed all at once + let size_prefix: [u8; 4] = compressed_body[..4].try_into().unwrap(); + let mut dec = StreamingDecompressor::new(size_prefix, Encoding::LittleEndian).unwrap(); + dec.feed(&compressed_body[4..]).unwrap(); + assert!(dec.is_complete()); + let streaming_result = dec.finish().unwrap(); + assert_eq!(streaming_result, batch_result, "all-at-once mismatch"); + + // Feed byte-by-byte + let mut dec = StreamingDecompressor::new(size_prefix, Encoding::LittleEndian).unwrap(); + for &byte in &compressed_body[4..] 
{ + dec.feed(&[byte]).unwrap(); + } + assert!(dec.is_complete()); + let streaming_result = dec.finish().unwrap(); + assert_eq!(streaming_result, batch_result, "byte-by-byte mismatch"); + } + + #[test] + fn streaming_decompressor_empty_body() { + // Size prefix says 8 bytes total (header only), so decompressed size = 0 + let size_prefix = 8_i32.to_le_bytes(); + let dec = StreamingDecompressor::new(size_prefix, Encoding::LittleEndian).unwrap(); + // No data to feed — already complete + assert!(dec.is_complete()); + assert_eq!(dec.decompressed_len(), 0); + let result = dec.finish().unwrap(); + assert!(result.is_empty()); + } + + #[test] + fn streaming_decompressor_rejects_small_size() { + let size_prefix = 4_i32.to_le_bytes(); + assert!(StreamingDecompressor::new(size_prefix, Encoding::LittleEndian).is_err()); + } + + #[test] + fn streaming_decompressor_finish_before_complete() { + // Size says 16 bytes decompressed (24 total - 8 header) + let size_prefix = 24_i32.to_le_bytes(); + let dec = StreamingDecompressor::new(size_prefix, Encoding::LittleEndian).unwrap(); + assert!(!dec.is_complete()); + assert!(dec.finish().is_err()); + } + + #[test] + fn streaming_decompressor_literal_only() { + // Build a compressed payload that's all literals (no back-references). + // Flag byte 0x00 means all 8 bits are "literal". + // For a 3-byte decompressed output: + // size_prefix = (8 + 3) = 11 + // compressed: [flag=0x00] [lit1] [lit2] [lit3] + let size_prefix = 11_i32.to_le_bytes(); + let mut compressed = Vec::new(); + compressed.extend_from_slice(&size_prefix); + compressed.push(0x00); // flag: 8 literals + compressed.push(0x41); // 'A' + compressed.push(0x42); // 'B' + compressed.push(0x43); // 'C' + + assert_streaming_matches_batch(&compressed); + } +} diff --git a/crates/qroissant-core/src/lib.rs b/crates/qroissant-core/src/lib.rs new file mode 100644 index 0000000..1ce6637 --- /dev/null +++ b/crates/qroissant-core/src/lib.rs @@ -0,0 +1,61 @@ +//! 
q IPC protocol and value semantics for qroissant. +//! +//! This crate provides the core building blocks for encoding, decoding, and +//! representing q/kdb+ IPC messages: +//! +//! - **`protocol`** — type codes, primitives, shapes, and attributes that +//! define the q wire format. +//! - **`value`** — the `Value` enum and its variants (`Atom`, `Vector`, +//! `List`, `Dictionary`, `Table`) that model q data in Rust. +//! - **`frame`** — message framing, header parsing, compression, and the +//! `StreamingDecompressor` for incremental LZW decompression. +//! - **`decode`** — synchronous message and value decoding with optional +//! parallel column decode via rayon. +//! - **`encode`** — serialisation of `Value` trees into q IPC byte frames. +//! - **`pipelined`** — asynchronous (`tokio::io::AsyncRead`) value decoder +//! for streaming use cases. +//! - **`extent`** — zero-allocation byte extent scanning used to locate +//! column boundaries for parallel decode. + +pub mod decode; +pub mod encode; +pub mod error; +pub mod extent; +pub mod frame; +pub mod pipelined; +pub mod protocol; +pub mod value; + +pub use decode::DecodeOptions; +pub use decode::DecodedMessage; +pub use decode::decode_message; +pub use decode::decode_message_with_options; +pub use decode::decode_value; +pub use decode::decode_value_with_options; +pub use encode::encode_message; +pub use encode::encode_value; +pub use error::CoreError; +pub use error::CoreResult; +pub use extent::value_byte_extent; +pub use frame::Compression; +pub use frame::Encoding; +pub use frame::Frame; +pub use frame::HEADER_LEN; +pub use frame::MessageHeader; +pub use frame::MessageType; +pub use frame::StreamingDecompressor; +pub use frame::read_frame; +pub use frame::read_message_length; +pub use frame::serialize_body_as_message; +pub use protocol::Attribute; +pub use protocol::Primitive; +pub use protocol::Shape; +pub use protocol::TypeCode; +pub use protocol::ValueType; +pub use value::Atom; +pub use value::Dictionary; 
+pub use value::List; +pub use value::Table; +pub use value::Value; +pub use value::Vector; +pub use value::VectorData; diff --git a/crates/qroissant-core/src/pipelined.rs b/crates/qroissant-core/src/pipelined.rs new file mode 100644 index 0000000..a18a4e4 --- /dev/null +++ b/crates/qroissant-core/src/pipelined.rs @@ -0,0 +1,390 @@ +use futures::future::BoxFuture; +use futures::future::FutureExt; +use tokio::io::AsyncRead; +use tokio::io::AsyncReadExt; + +use crate::decode::extract_columns; +use crate::decode::extract_symbol_names; +use crate::error::CoreError; +use crate::error::CoreResult; +use crate::frame::Encoding; +use crate::protocol::Attribute; +use crate::protocol::Primitive; +use crate::protocol::TypeCode; +use crate::value::Atom; +use crate::value::Dictionary; +use crate::value::List; +use crate::value::Table; +use crate::value::Value; +use crate::value::Vector; +use crate::value::VectorData; + +/// Asynchronous reader for q value components. +/// +/// Wraps an `AsyncRead` source and provides async methods to read +/// primitive types and byte chunks, allowing the decoder to wait +/// for data without blocking. +/// +/// Only little-endian payloads are supported (matching the rest of qroissant). +pub struct PipelinedReader { + reader: R, +} + +impl PipelinedReader { + /// Creates a new pipelined reader. + /// + /// Returns `UnsupportedEndianness` for big-endian payloads, matching + /// the behaviour of `decode_value()` and `decode_message()`. + pub fn new(reader: R, encoding: Encoding) -> CoreResult { + if encoding != Encoding::LittleEndian { + return Err(CoreError::UnsupportedEndianness(encoding)); + } + Ok(Self { reader }) + } + + pub async fn read_u8(&mut self) -> CoreResult { + let mut buf = [0_u8; 1]; + self.reader.read_exact(&mut buf).await?; + Ok(buf[0]) + } + + pub async fn read_i8(&mut self) -> CoreResult { + Ok(self.read_u8().await? 
as i8) + } + + pub async fn read_i16(&mut self) -> CoreResult { + let mut buf = [0_u8; 2]; + self.reader.read_exact(&mut buf).await?; + Ok(i16::from_le_bytes(buf)) + } + + pub async fn read_i32(&mut self) -> CoreResult { + let mut buf = [0_u8; 4]; + self.reader.read_exact(&mut buf).await?; + Ok(i32::from_le_bytes(buf)) + } + + pub async fn read_i64(&mut self) -> CoreResult { + let mut buf = [0_u8; 8]; + self.reader.read_exact(&mut buf).await?; + Ok(i64::from_le_bytes(buf)) + } + + pub async fn read_f32(&mut self) -> CoreResult { + let mut buf = [0_u8; 4]; + self.reader.read_exact(&mut buf).await?; + Ok(f32::from_le_bytes(buf)) + } + + pub async fn read_f64(&mut self) -> CoreResult { + let mut buf = [0_u8; 8]; + self.reader.read_exact(&mut buf).await?; + Ok(f64::from_le_bytes(buf)) + } + + pub async fn read_guid(&mut self) -> CoreResult<[u8; 16]> { + let mut buf = [0_u8; 16]; + self.reader.read_exact(&mut buf).await?; + Ok(buf) + } + + pub async fn read_length(&mut self) -> CoreResult { + let length = self.read_i32().await?; + usize::try_from(length).map_err(|_| CoreError::InvalidCollectionLength(length)) + } + + pub async fn read_bytes(&mut self, len: usize) -> CoreResult { + let mut buf = vec![0_u8; len]; + self.reader.read_exact(&mut buf).await?; + Ok(bytes::Bytes::from(buf)) + } + + /// Reads a null-terminated symbol. + /// + /// Reads one byte at a time until a null terminator is found. + /// In practice the underlying reader is buffered (e.g. `BufReader` + /// or `DecompressingReader` with an 8 KB buffer), so single-byte + /// `read_exact` calls are cheap — they copy from the user-space buffer + /// without issuing a syscall. 
+ pub async fn read_symbol(&mut self) -> CoreResult { + let mut buf = Vec::new(); + loop { + let b = self.read_u8().await?; + if b == 0 { + return Ok(bytes::Bytes::from(buf)); + } + buf.push(b); + } + } + + pub async fn read_vec( + &mut self, + count: usize, + ) -> CoreResult> { + let _byte_len = count + .checked_mul(std::mem::size_of::()) + .ok_or(CoreError::LengthOverflow(count))?; + let mut values = vec![T::zeroed(); count]; + let dst: &mut [u8] = bytemuck::cast_slice_mut(&mut values); + self.reader.read_exact(dst).await?; + Ok(values) + } +} + +pub async fn decode_value_async( + reader: &mut PipelinedReader, +) -> CoreResult { + decode_inner_async(reader).await +} + +fn decode_inner_async<'a, R: AsyncRead + Unpin + Send>( + reader: &'a mut PipelinedReader, +) -> BoxFuture<'a, CoreResult> { + async move { + let type_code_byte = reader.read_i8().await?; + let type_code = TypeCode::try_from(type_code_byte)?; + match type_code.shape() { + crate::protocol::Shape::Atom => { + let primitive = type_code + .primitive() + .ok_or(CoreError::InvalidTypeCode(type_code.into()))?; + Ok(Value::Atom(decode_atom_async(reader, primitive).await?)) + } + crate::protocol::Shape::Vector => { + let primitive = type_code + .primitive() + .ok_or(CoreError::InvalidTypeCode(type_code.into()))?; + let attribute = Attribute::try_from(reader.read_i8().await?)?; + let length = reader.read_length().await?; + Ok(Value::Vector( + decode_vector_async(reader, primitive, attribute, length).await?, + )) + } + crate::protocol::Shape::List => { + let attribute = Attribute::try_from(reader.read_i8().await?)?; + let length = reader.read_length().await?; + let mut values = Vec::with_capacity(length); + for _ in 0..length { + values.push(decode_inner_async(reader).await?); + } + Ok(Value::List(List::new(attribute, values))) + } + crate::protocol::Shape::Dictionary => { + let sorted = type_code == TypeCode::SortedDictionary; + let keys = decode_inner_async(reader).await?; + let values = 
decode_inner_async(reader).await?; + let dict = Dictionary::new(sorted, keys, values); + dict.validate()?; + Ok(Value::Dictionary(dict)) + } + crate::protocol::Shape::Table => { + let attribute = Attribute::try_from(reader.read_i8().await?)?; + let dict_value = decode_inner_async(reader).await?; + match dict_value { + Value::Dictionary(dict) => { + let names = extract_symbol_names(dict.keys())?; + let columns = extract_columns(dict.values())?; + let table = Table::new(attribute, names, columns); + table.validate()?; + Ok(Value::Table(table)) + } + _ => Err(CoreError::InvalidStructure( + "q table payload must contain a dictionary body".to_string(), + )), + } + } + crate::protocol::Shape::UnaryPrimitive => Ok(Value::UnaryPrimitive { + opcode: reader.read_i8().await?, + }), + crate::protocol::Shape::Error => { + let error_msg = reader.read_symbol().await?; + Err(CoreError::QRuntime( + String::from_utf8_lossy(&error_msg).into(), + )) + } + } + } + .boxed() +} + +async fn decode_atom_async( + reader: &mut PipelinedReader, + primitive: Primitive, +) -> CoreResult { + Ok(match primitive { + Primitive::Boolean => Atom::Boolean(reader.read_u8().await? 
!= 0), + Primitive::Guid => Atom::Guid(reader.read_guid().await?), + Primitive::Byte => Atom::Byte(reader.read_u8().await?), + Primitive::Short => Atom::Short(reader.read_i16().await?), + Primitive::Int => Atom::Int(reader.read_i32().await?), + Primitive::Long => Atom::Long(reader.read_i64().await?), + Primitive::Real => Atom::Real(reader.read_f32().await?), + Primitive::Float => Atom::Float(reader.read_f64().await?), + Primitive::Char => Atom::Char(reader.read_u8().await?), + Primitive::Symbol => Atom::Symbol(reader.read_symbol().await?), + Primitive::Timestamp => Atom::Timestamp(reader.read_i64().await?), + Primitive::Month => Atom::Month(reader.read_i32().await?), + Primitive::Date => Atom::Date(reader.read_i32().await?), + Primitive::Datetime => Atom::Datetime(reader.read_f64().await?), + Primitive::Timespan => Atom::Timespan(reader.read_i64().await?), + Primitive::Minute => Atom::Minute(reader.read_i32().await?), + Primitive::Second => Atom::Second(reader.read_i32().await?), + Primitive::Time => Atom::Time(reader.read_i32().await?), + Primitive::Mixed => unreachable!("mixed values are not encoded as atoms"), + }) +} + +async fn decode_vector_async( + reader: &mut PipelinedReader, + primitive: Primitive, + attribute: Attribute, + length: usize, +) -> CoreResult { + let data = match primitive { + Primitive::Boolean => VectorData::Boolean(reader.read_bytes(length).await?), + Primitive::Guid => { + let byte_len = length + .checked_mul(16) + .ok_or(CoreError::LengthOverflow(length))?; + VectorData::Guid(reader.read_bytes(byte_len).await?) 
+ } + Primitive::Byte => VectorData::Byte(reader.read_bytes(length).await?), + Primitive::Short => VectorData::Short(reader.read_bytes(length * 2).await?), + Primitive::Int => VectorData::Int(reader.read_bytes(length * 4).await?), + Primitive::Long => VectorData::Long(reader.read_bytes(length * 8).await?), + Primitive::Real => VectorData::Real(reader.read_bytes(length * 4).await?), + Primitive::Float => VectorData::Float(reader.read_bytes(length * 8).await?), + Primitive::Char => VectorData::Char(reader.read_bytes(length).await?), + Primitive::Symbol => { + let mut values = Vec::with_capacity(length); + for _ in 0..length { + values.push(reader.read_symbol().await?); + } + VectorData::Symbol(values) + } + Primitive::Timestamp => VectorData::Timestamp(reader.read_bytes(length * 8).await?), + Primitive::Month => VectorData::Month(reader.read_bytes(length * 4).await?), + Primitive::Date => VectorData::Date(reader.read_bytes(length * 4).await?), + Primitive::Datetime => VectorData::Datetime(reader.read_bytes(length * 8).await?), + Primitive::Timespan => VectorData::Timespan(reader.read_bytes(length * 8).await?), + Primitive::Minute => VectorData::Minute(reader.read_bytes(length * 4).await?), + Primitive::Second => VectorData::Second(reader.read_bytes(length * 4).await?), + Primitive::Time => VectorData::Time(reader.read_bytes(length * 4).await?), + Primitive::Mixed => unreachable!("mixed values are not encoded as vectors"), + }; + + Ok(Vector::new(attribute, data)) +} + +#[cfg(test)] +mod tests { + use std::io::Cursor; + + use super::*; + + #[tokio::test] + async fn test_decode_atom_async() -> CoreResult<()> { + let mut data = Vec::new(); + data.push(TypeCode::IntAtom as u8); + data.extend_from_slice(&42_i32.to_le_bytes()); + + let mut reader = PipelinedReader::new(Cursor::new(data), Encoding::LittleEndian).unwrap(); + let value = decode_value_async(&mut reader).await?; + + assert_eq!(value, Value::Atom(Atom::Int(42))); + Ok(()) + } + + #[tokio::test] + async fn 
test_decode_vector_async() -> CoreResult<()> { + let mut data = Vec::new(); + data.push(TypeCode::IntVector as u8); + data.push(0_u8); // attribute None + data.extend_from_slice(&2_i32.to_le_bytes()); // length 2 + data.extend_from_slice(&10_i32.to_le_bytes()); + data.extend_from_slice(&20_i32.to_le_bytes()); + + let mut reader = PipelinedReader::new(Cursor::new(data), Encoding::LittleEndian).unwrap(); + let value = decode_value_async(&mut reader).await?; + + match &value { + Value::Vector(vector) => { + assert_eq!(vector.data().as_i32_slice(), &[10, 20]); + } + _ => panic!("Expected Vector, got {:?}", value), + } + Ok(()) + } + + #[tokio::test] + async fn test_decode_table_async() -> CoreResult<()> { + let mut data = Vec::new(); + data.push(TypeCode::Table as u8); + data.push(0_u8); // attribute None + + // Dictionary prefix + data.push(TypeCode::Dictionary as u8); + + // Dictionary (keys) + data.push(TypeCode::SymbolVector as u8); + data.push(0_u8); // attribute None + data.extend_from_slice(&1_i32.to_le_bytes()); // 1 column name + data.extend_from_slice(b"col1\0"); + + // Dictionary (values) + data.push(TypeCode::GeneralList as u8); + data.push(0_u8); // attribute None + data.extend_from_slice(&1_i32.to_le_bytes()); // 1 column + + // Column 1: Int Vector [100, 200] + data.push(TypeCode::IntVector as u8); + data.push(0_u8); + data.extend_from_slice(&2_i32.to_le_bytes()); + data.extend_from_slice(&100_i32.to_le_bytes()); + data.extend_from_slice(&200_i32.to_le_bytes()); + + let mut reader = PipelinedReader::new(Cursor::new(data), Encoding::LittleEndian).unwrap(); + let value = decode_value_async(&mut reader).await?; + + match &value { + Value::Table(table) => { + assert_eq!(table.num_columns(), 1); + assert_eq!(&table.column_names()[0][..], b"col1"); + match &table.columns()[0] { + Value::Vector(v) => { + assert_eq!(v.data().as_i32_slice(), &[100, 200]); + } + _ => panic!("Expected Vector"), + } + } + _ => panic!("Expected Table, got {:?}", value), + } + Ok(()) 
+ } + + #[tokio::test] + async fn test_rejects_big_endian() { + let result = PipelinedReader::new(Cursor::new(vec![]), Encoding::BigEndian); + assert!(matches!( + result, + Err(CoreError::UnsupportedEndianness(Encoding::BigEndian)) + )); + } + + #[tokio::test] + async fn test_negative_length_gives_proper_error() -> CoreResult<()> { + let mut data = Vec::new(); + data.push(TypeCode::IntVector as u8); + data.push(0_u8); // attribute None + data.extend_from_slice(&(-1_i32).to_le_bytes()); // negative length + + let mut reader = PipelinedReader::new(Cursor::new(data), Encoding::LittleEndian).unwrap(); + let err = decode_value_async(&mut reader).await.unwrap_err(); + assert!( + matches!(err, CoreError::InvalidCollectionLength(-1)), + "expected InvalidCollectionLength(-1), got {:?}", + err + ); + Ok(()) + } +} diff --git a/crates/qroissant-core/src/protocol.rs b/crates/qroissant-core/src/protocol.rs new file mode 100644 index 0000000..7b9b3f5 --- /dev/null +++ b/crates/qroissant-core/src/protocol.rs @@ -0,0 +1,373 @@ +use crate::error::CoreError; +use crate::error::CoreResult; + +/// q attribute attached to vectors, lists, and tables. +#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)] +pub enum Attribute { + #[default] + None, + Sorted, + Unique, + Parted, + Grouped, +} + +impl From for i8 { + fn from(value: Attribute) -> Self { + match value { + Attribute::None => 0, + Attribute::Sorted => 1, + Attribute::Unique => 2, + Attribute::Parted => 3, + Attribute::Grouped => 4, + } + } +} + +impl TryFrom for Attribute { + type Error = CoreError; + + fn try_from(value: i8) -> CoreResult { + match value { + 0 => Ok(Self::None), + 1 => Ok(Self::Sorted), + 2 => Ok(Self::Unique), + 3 => Ok(Self::Parted), + 4 => Ok(Self::Grouped), + _ => Err(CoreError::InvalidAttribute(value)), + } + } +} + +/// q primitive domain shared by atoms and homogeneous vectors. 
+#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum Primitive { + Boolean, + Guid, + Byte, + Short, + Int, + Long, + Real, + Float, + Char, + Symbol, + Timestamp, + Month, + Date, + Datetime, + Timespan, + Minute, + Second, + Time, + Mixed, +} + +impl Primitive { + /// Fixed-width byte width for primitives that have one on the wire. + pub fn width(self) -> Option { + match self { + Self::Boolean | Self::Byte | Self::Char => Some(1), + Self::Short => Some(2), + Self::Int + | Self::Real + | Self::Month + | Self::Date + | Self::Minute + | Self::Second + | Self::Time => Some(4), + Self::Long | Self::Float | Self::Timestamp | Self::Datetime | Self::Timespan => Some(8), + Self::Guid => Some(16), + Self::Symbol | Self::Mixed => None, + } + } +} + +/// Top-level q structural shape. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum Shape { + Atom, + Vector, + List, + Dictionary, + Table, + UnaryPrimitive, + Error, +} + +/// Complete q type descriptor for a decoded value. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub struct ValueType { + pub primitive: Option, + pub shape: Shape, + pub attribute: Option, + pub sorted: Option, +} + +impl ValueType { + pub fn atom(primitive: Primitive) -> Self { + Self { + primitive: Some(primitive), + shape: Shape::Atom, + attribute: None, + sorted: None, + } + } + + pub fn vector(primitive: Primitive, attribute: Attribute) -> Self { + Self { + primitive: Some(primitive), + shape: Shape::Vector, + attribute: Some(attribute), + sorted: None, + } + } + + pub fn list(attribute: Attribute) -> Self { + Self { + primitive: Some(Primitive::Mixed), + shape: Shape::List, + attribute: Some(attribute), + sorted: None, + } + } + + pub fn dictionary(sorted: bool) -> Self { + Self { + primitive: None, + shape: Shape::Dictionary, + attribute: None, + sorted: Some(sorted), + } + } + + pub fn table(attribute: Attribute) -> Self { + Self { + primitive: None, + shape: Shape::Table, + attribute: Some(attribute), + sorted: None, + } + } + + pub fn 
unary_primitive() -> Self { + Self { + primitive: None, + shape: Shape::UnaryPrimitive, + attribute: None, + sorted: None, + } + } +} + +/// Raw q IPC type code. +#[repr(i8)] +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum TypeCode { + GeneralList = 0, + BooleanVector = 1, + GuidVector = 2, + ByteVector = 4, + ShortVector = 5, + IntVector = 6, + LongVector = 7, + RealVector = 8, + FloatVector = 9, + CharVector = 10, + SymbolVector = 11, + TimestampVector = 12, + MonthVector = 13, + DateVector = 14, + DatetimeVector = 15, + TimespanVector = 16, + MinuteVector = 17, + SecondVector = 18, + TimeVector = 19, + Table = 98, + Dictionary = 99, + UnaryPrimitive = 101, + SortedDictionary = 127, + BooleanAtom = -1, + GuidAtom = -2, + ByteAtom = -4, + ShortAtom = -5, + IntAtom = -6, + LongAtom = -7, + RealAtom = -8, + FloatAtom = -9, + CharAtom = -10, + SymbolAtom = -11, + TimestampAtom = -12, + MonthAtom = -13, + DateAtom = -14, + DatetimeAtom = -15, + TimespanAtom = -16, + MinuteAtom = -17, + SecondAtom = -18, + TimeAtom = -19, + ErrorCode = -128, +} + +impl TypeCode { + pub fn primitive(self) -> Option { + match self { + Self::BooleanAtom | Self::BooleanVector => Some(Primitive::Boolean), + Self::GuidAtom | Self::GuidVector => Some(Primitive::Guid), + Self::ByteAtom | Self::ByteVector => Some(Primitive::Byte), + Self::ShortAtom | Self::ShortVector => Some(Primitive::Short), + Self::IntAtom | Self::IntVector => Some(Primitive::Int), + Self::LongAtom | Self::LongVector => Some(Primitive::Long), + Self::RealAtom | Self::RealVector => Some(Primitive::Real), + Self::FloatAtom | Self::FloatVector => Some(Primitive::Float), + Self::CharAtom | Self::CharVector => Some(Primitive::Char), + Self::SymbolAtom | Self::SymbolVector => Some(Primitive::Symbol), + Self::TimestampAtom | Self::TimestampVector => Some(Primitive::Timestamp), + Self::MonthAtom | Self::MonthVector => Some(Primitive::Month), + Self::DateAtom | Self::DateVector => Some(Primitive::Date), + Self::DatetimeAtom 
| Self::DatetimeVector => Some(Primitive::Datetime), + Self::TimespanAtom | Self::TimespanVector => Some(Primitive::Timespan), + Self::MinuteAtom | Self::MinuteVector => Some(Primitive::Minute), + Self::SecondAtom | Self::SecondVector => Some(Primitive::Second), + Self::TimeAtom | Self::TimeVector => Some(Primitive::Time), + Self::GeneralList + | Self::Table + | Self::Dictionary + | Self::UnaryPrimitive + | Self::SortedDictionary + | Self::ErrorCode => None, + } + } + + pub fn shape(self) -> Shape { + match self { + Self::BooleanAtom + | Self::GuidAtom + | Self::ByteAtom + | Self::ShortAtom + | Self::IntAtom + | Self::LongAtom + | Self::RealAtom + | Self::FloatAtom + | Self::CharAtom + | Self::SymbolAtom + | Self::TimestampAtom + | Self::MonthAtom + | Self::DateAtom + | Self::DatetimeAtom + | Self::TimespanAtom + | Self::MinuteAtom + | Self::SecondAtom + | Self::TimeAtom => Shape::Atom, + Self::BooleanVector + | Self::GuidVector + | Self::ByteVector + | Self::ShortVector + | Self::IntVector + | Self::LongVector + | Self::RealVector + | Self::FloatVector + | Self::CharVector + | Self::SymbolVector + | Self::TimestampVector + | Self::MonthVector + | Self::DateVector + | Self::DatetimeVector + | Self::TimespanVector + | Self::MinuteVector + | Self::SecondVector + | Self::TimeVector => Shape::Vector, + Self::GeneralList => Shape::List, + Self::Dictionary | Self::SortedDictionary => Shape::Dictionary, + Self::Table => Shape::Table, + Self::UnaryPrimitive => Shape::UnaryPrimitive, + Self::ErrorCode => Shape::Error, + } + } +} + +impl From for i8 { + fn from(value: TypeCode) -> Self { + value as i8 + } +} + +impl TryFrom for TypeCode { + type Error = CoreError; + + fn try_from(value: i8) -> CoreResult { + match value { + 0 => Ok(Self::GeneralList), + 1 => Ok(Self::BooleanVector), + 2 => Ok(Self::GuidVector), + 4 => Ok(Self::ByteVector), + 5 => Ok(Self::ShortVector), + 6 => Ok(Self::IntVector), + 7 => Ok(Self::LongVector), + 8 => Ok(Self::RealVector), + 9 => 
Ok(Self::FloatVector), + 10 => Ok(Self::CharVector), + 11 => Ok(Self::SymbolVector), + 12 => Ok(Self::TimestampVector), + 13 => Ok(Self::MonthVector), + 14 => Ok(Self::DateVector), + 15 => Ok(Self::DatetimeVector), + 16 => Ok(Self::TimespanVector), + 17 => Ok(Self::MinuteVector), + 18 => Ok(Self::SecondVector), + 19 => Ok(Self::TimeVector), + 98 => Ok(Self::Table), + 99 => Ok(Self::Dictionary), + 101 => Ok(Self::UnaryPrimitive), + 127 => Ok(Self::SortedDictionary), + -1 => Ok(Self::BooleanAtom), + -2 => Ok(Self::GuidAtom), + -4 => Ok(Self::ByteAtom), + -5 => Ok(Self::ShortAtom), + -6 => Ok(Self::IntAtom), + -7 => Ok(Self::LongAtom), + -8 => Ok(Self::RealAtom), + -9 => Ok(Self::FloatAtom), + -10 => Ok(Self::CharAtom), + -11 => Ok(Self::SymbolAtom), + -12 => Ok(Self::TimestampAtom), + -13 => Ok(Self::MonthAtom), + -14 => Ok(Self::DateAtom), + -15 => Ok(Self::DatetimeAtom), + -16 => Ok(Self::TimespanAtom), + -17 => Ok(Self::MinuteAtom), + -18 => Ok(Self::SecondAtom), + -19 => Ok(Self::TimeAtom), + -128 => Ok(Self::ErrorCode), + _ => Err(CoreError::InvalidTypeCode(value)), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn attribute_round_trips() { + assert_eq!(Attribute::try_from(0).unwrap(), Attribute::None); + assert_eq!(Attribute::try_from(4).unwrap(), Attribute::Grouped); + assert!(matches!( + Attribute::try_from(9), + Err(CoreError::InvalidAttribute(9)) + )); + } + + #[test] + fn type_code_maps_to_expected_shape_and_primitive() { + let atom = TypeCode::IntAtom; + let vector = TypeCode::SymbolVector; + let list = TypeCode::GeneralList; + + assert_eq!(atom.shape(), Shape::Atom); + assert_eq!(atom.primitive(), Some(Primitive::Int)); + assert_eq!(vector.shape(), Shape::Vector); + assert_eq!(vector.primitive(), Some(Primitive::Symbol)); + assert_eq!(list.shape(), Shape::List); + assert_eq!(list.primitive(), None); + } +} diff --git a/crates/qroissant-core/src/value.rs b/crates/qroissant-core/src/value.rs new file mode 100644 index 
0000000..a5892ab --- /dev/null +++ b/crates/qroissant-core/src/value.rs @@ -0,0 +1,479 @@ +use bytes::Bytes; + +use crate::error::CoreError; +use crate::error::CoreResult; +use crate::protocol::Attribute; +use crate::protocol::Primitive; +use crate::protocol::ValueType; + +/// q atom payload. +#[derive(Clone, Debug, PartialEq)] +pub enum Atom { + Boolean(bool), + Guid([u8; 16]), + Byte(u8), + Short(i16), + Int(i32), + Long(i64), + Real(f32), + Float(f64), + Char(u8), + Symbol(Bytes), + Timestamp(i64), + Month(i32), + Date(i32), + Datetime(f64), + Timespan(i64), + Minute(i32), + Second(i32), + Time(i32), +} + +impl Atom { + pub fn primitive(&self) -> Primitive { + match self { + Self::Boolean(_) => Primitive::Boolean, + Self::Guid(_) => Primitive::Guid, + Self::Byte(_) => Primitive::Byte, + Self::Short(_) => Primitive::Short, + Self::Int(_) => Primitive::Int, + Self::Long(_) => Primitive::Long, + Self::Real(_) => Primitive::Real, + Self::Float(_) => Primitive::Float, + Self::Char(_) => Primitive::Char, + Self::Symbol(_) => Primitive::Symbol, + Self::Timestamp(_) => Primitive::Timestamp, + Self::Month(_) => Primitive::Month, + Self::Date(_) => Primitive::Date, + Self::Datetime(_) => Primitive::Datetime, + Self::Timespan(_) => Primitive::Timespan, + Self::Minute(_) => Primitive::Minute, + Self::Second(_) => Primitive::Second, + Self::Time(_) => Primitive::Time, + } + } +} + +/// q homogeneous vector payload. +/// +/// All fixed-width numeric types store their data as raw [`Bytes`], enabling +/// zero-copy slicing from the IPC frame buffer during decode. Typed access +/// is provided via `as_*_slice()` methods using `bytemuck::cast_slice`. 
+#[derive(Clone, Debug, PartialEq)] +pub enum VectorData { + Boolean(Bytes), + Guid(Bytes), + Byte(Bytes), + Short(Bytes), + Int(Bytes), + Long(Bytes), + Real(Bytes), + Float(Bytes), + Char(Bytes), + Symbol(Vec), + Timestamp(Bytes), + Month(Bytes), + Date(Bytes), + Datetime(Bytes), + Timespan(Bytes), + Minute(Bytes), + Second(Bytes), + Time(Bytes), +} + +impl VectorData { + pub fn primitive(&self) -> Primitive { + match self { + Self::Boolean(_) => Primitive::Boolean, + Self::Guid(_) => Primitive::Guid, + Self::Byte(_) => Primitive::Byte, + Self::Short(_) => Primitive::Short, + Self::Int(_) => Primitive::Int, + Self::Long(_) => Primitive::Long, + Self::Real(_) => Primitive::Real, + Self::Float(_) => Primitive::Float, + Self::Char(_) => Primitive::Char, + Self::Symbol(_) => Primitive::Symbol, + Self::Timestamp(_) => Primitive::Timestamp, + Self::Month(_) => Primitive::Month, + Self::Date(_) => Primitive::Date, + Self::Datetime(_) => Primitive::Datetime, + Self::Timespan(_) => Primitive::Timespan, + Self::Minute(_) => Primitive::Minute, + Self::Second(_) => Primitive::Second, + Self::Time(_) => Primitive::Time, + } + } + + pub fn len(&self) -> usize { + match self { + Self::Boolean(b) | Self::Byte(b) | Self::Char(b) => b.len(), + Self::Guid(b) => b.len() / 16, + Self::Short(b) => b.len() / 2, + Self::Int(b) + | Self::Month(b) + | Self::Date(b) + | Self::Minute(b) + | Self::Second(b) + | Self::Time(b) + | Self::Real(b) => b.len() / 4, + Self::Long(b) + | Self::Timestamp(b) + | Self::Timespan(b) + | Self::Float(b) + | Self::Datetime(b) => b.len() / 8, + Self::Symbol(v) => v.len(), + } + } + + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + /// Returns the underlying raw bytes for non-Symbol variants. 
+ pub fn raw_bytes(&self) -> Option<&Bytes> { + match self { + Self::Symbol(_) => None, + Self::Boolean(b) + | Self::Guid(b) + | Self::Byte(b) + | Self::Short(b) + | Self::Int(b) + | Self::Long(b) + | Self::Real(b) + | Self::Float(b) + | Self::Char(b) + | Self::Timestamp(b) + | Self::Month(b) + | Self::Date(b) + | Self::Datetime(b) + | Self::Timespan(b) + | Self::Minute(b) + | Self::Second(b) + | Self::Time(b) => Some(b), + } + } + + pub fn as_i16_slice(&self) -> &[i16] { + match self { + Self::Short(b) => bytemuck::cast_slice(b), + _ => panic!("as_i16_slice called on {:?}", self.primitive()), + } + } + + pub fn as_i32_slice(&self) -> &[i32] { + match self { + Self::Int(b) + | Self::Month(b) + | Self::Date(b) + | Self::Minute(b) + | Self::Second(b) + | Self::Time(b) => bytemuck::cast_slice(b), + _ => panic!("as_i32_slice called on {:?}", self.primitive()), + } + } + + pub fn as_i64_slice(&self) -> &[i64] { + match self { + Self::Long(b) | Self::Timestamp(b) | Self::Timespan(b) => bytemuck::cast_slice(b), + _ => panic!("as_i64_slice called on {:?}", self.primitive()), + } + } + + pub fn as_f32_slice(&self) -> &[f32] { + match self { + Self::Real(b) => bytemuck::cast_slice(b), + _ => panic!("as_f32_slice called on {:?}", self.primitive()), + } + } + + pub fn as_f64_slice(&self) -> &[f64] { + match self { + Self::Float(b) | Self::Datetime(b) => bytemuck::cast_slice(b), + _ => panic!("as_f64_slice called on {:?}", self.primitive()), + } + } + + // Construction helpers for tests and ingestion paths. 
+ + pub fn from_i16s(values: &[i16]) -> Self { + Self::Short(Bytes::copy_from_slice(bytemuck::cast_slice(values))) + } + + pub fn from_i32s(values: &[i32]) -> Self { + Self::Int(Bytes::copy_from_slice(bytemuck::cast_slice(values))) + } + + pub fn from_i64s(values: &[i64]) -> Self { + Self::Long(Bytes::copy_from_slice(bytemuck::cast_slice(values))) + } + + pub fn from_f32s(values: &[f32]) -> Self { + Self::Real(Bytes::copy_from_slice(bytemuck::cast_slice(values))) + } + + pub fn from_f64s(values: &[f64]) -> Self { + Self::Float(Bytes::copy_from_slice(bytemuck::cast_slice(values))) + } + + pub fn from_guids(values: &[[u8; 16]]) -> Self { + let mut buf = Vec::with_capacity(values.len() * 16); + for guid in values { + buf.extend_from_slice(guid); + } + Self::Guid(Bytes::from(buf)) + } + + pub fn from_timestamps(values: &[i64]) -> Self { + Self::Timestamp(Bytes::copy_from_slice(bytemuck::cast_slice(values))) + } + + pub fn from_months(values: &[i32]) -> Self { + Self::Month(Bytes::copy_from_slice(bytemuck::cast_slice(values))) + } + + pub fn from_dates(values: &[i32]) -> Self { + Self::Date(Bytes::copy_from_slice(bytemuck::cast_slice(values))) + } + + pub fn from_datetimes(values: &[f64]) -> Self { + Self::Datetime(Bytes::copy_from_slice(bytemuck::cast_slice(values))) + } + + pub fn from_timespans(values: &[i64]) -> Self { + Self::Timespan(Bytes::copy_from_slice(bytemuck::cast_slice(values))) + } + + pub fn from_minutes(values: &[i32]) -> Self { + Self::Minute(Bytes::copy_from_slice(bytemuck::cast_slice(values))) + } + + pub fn from_seconds(values: &[i32]) -> Self { + Self::Second(Bytes::copy_from_slice(bytemuck::cast_slice(values))) + } + + pub fn from_times(values: &[i32]) -> Self { + Self::Time(Bytes::copy_from_slice(bytemuck::cast_slice(values))) + } +} + +/// q homogeneous vector with an attached q attribute. 
+#[derive(Clone, Debug, PartialEq)] +pub struct Vector { + attribute: Attribute, + data: VectorData, +} + +impl Vector { + pub fn new(attribute: Attribute, data: VectorData) -> Self { + Self { attribute, data } + } + + pub fn attribute(&self) -> Attribute { + self.attribute + } + + pub fn primitive(&self) -> Primitive { + self.data.primitive() + } + + pub fn len(&self) -> usize { + self.data.len() + } + + pub fn is_empty(&self) -> bool { + self.data.is_empty() + } + + pub fn data(&self) -> &VectorData { + &self.data + } +} + +/// q general list. +#[derive(Clone, Debug, PartialEq)] +pub struct List { + attribute: Attribute, + values: Vec, +} + +impl List { + pub fn new(attribute: Attribute, values: Vec) -> Self { + Self { attribute, values } + } + + pub fn attribute(&self) -> Attribute { + self.attribute + } + + pub fn len(&self) -> usize { + self.values.len() + } + + pub fn is_empty(&self) -> bool { + self.values.is_empty() + } + + pub fn values(&self) -> &[Value] { + &self.values + } +} + +/// q dictionary. +#[derive(Clone, Debug, PartialEq)] +pub struct Dictionary { + sorted: bool, + keys: Box, + values: Box, +} + +impl Dictionary { + pub fn new(sorted: bool, keys: Value, values: Value) -> Self { + Self { + sorted, + keys: Box::new(keys), + values: Box::new(values), + } + } + + pub fn sorted(&self) -> bool { + self.sorted + } + + pub fn keys(&self) -> &Value { + &self.keys + } + + pub fn values(&self) -> &Value { + &self.values + } + + pub fn len(&self) -> usize { + self.keys.len() + } + + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + pub fn validate(&self) -> CoreResult<()> { + if self.keys.len() != self.values.len() { + return Err(CoreError::InvalidStructure(format!( + "q dictionary key/value lengths differ: {} != {}", + self.keys.len(), + self.values.len() + ))); + } + + Ok(()) + } +} + +/// q table. 
+#[derive(Clone, Debug, PartialEq)] +pub struct Table { + attribute: Attribute, + column_names: Vec, + columns: Vec, +} + +impl Table { + pub fn new(attribute: Attribute, column_names: Vec, columns: Vec) -> Self { + Self { + attribute, + column_names, + columns, + } + } + + pub fn attribute(&self) -> Attribute { + self.attribute + } + + pub fn column_names(&self) -> &[Bytes] { + &self.column_names + } + + pub fn columns(&self) -> &[Value] { + &self.columns + } + + pub fn num_columns(&self) -> usize { + self.columns.len() + } + + pub fn len(&self) -> usize { + self.columns.first().map_or(0, Value::len) + } + + pub fn is_empty(&self) -> bool { + self.len() == 0 + } + + pub fn validate(&self) -> CoreResult<()> { + if self.column_names.len() != self.columns.len() { + return Err(CoreError::InvalidStructure(format!( + "q table column name count {} does not match column count {}", + self.column_names.len(), + self.columns.len() + ))); + } + + if let Some(expected_rows) = self.columns.first().map(Value::len) { + for column in self.columns.iter().skip(1) { + if column.len() != expected_rows { + return Err(CoreError::InvalidStructure(format!( + "q table column lengths differ: expected {expected_rows}, found {}", + column.len() + ))); + } + } + } + + Ok(()) + } +} + +/// Decoded q value subset currently supported by the rewrite. +#[derive(Clone, Debug, PartialEq)] +pub enum Value { + Atom(Atom), + Vector(Vector), + List(List), + Dictionary(Dictionary), + Table(Table), + UnaryPrimitive { opcode: i8 }, +} + +impl Value { + pub fn qtype(&self) -> ValueType { + match self { + Self::Atom(atom) => ValueType::atom(atom.primitive()), + Self::Vector(vector) => ValueType::vector(vector.primitive(), vector.attribute()), + Self::List(list) => ValueType::list(list.attribute()), + Self::Dictionary(dictionary) => ValueType::dictionary(dictionary.sorted()), + Self::Table(table) => ValueType::table(table.attribute()), + Self::UnaryPrimitive { .. 
} => ValueType::unary_primitive(), + } + } + + pub fn len(&self) -> usize { + match self { + Self::Atom(_) | Self::UnaryPrimitive { .. } => 1, + Self::Vector(vector) => vector.len(), + Self::List(list) => list.len(), + Self::Dictionary(dictionary) => dictionary.len(), + Self::Table(table) => table.len(), + } + } + + pub fn is_empty(&self) -> bool { + match self { + Self::Atom(_) | Self::UnaryPrimitive { .. } => false, + Self::Vector(vector) => vector.is_empty(), + Self::List(list) => list.is_empty(), + Self::Dictionary(dictionary) => dictionary.is_empty(), + Self::Table(table) => table.is_empty(), + } + } +} diff --git a/crates/qroissant-kernels/Cargo.toml b/crates/qroissant-kernels/Cargo.toml new file mode 100644 index 0000000..ad499f9 --- /dev/null +++ b/crates/qroissant-kernels/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "qroissant-kernels" +version.workspace = true +edition.workspace = true +license.workspace = true +publish = false + +[lib] +name = "qroissant_kernels" +path = "src/lib.rs" + diff --git a/crates/qroissant-kernels/src/boolean.rs b/crates/qroissant-kernels/src/boolean.rs new file mode 100644 index 0000000..93f2099 --- /dev/null +++ b/crates/qroissant-kernels/src/boolean.rs @@ -0,0 +1,121 @@ +//! SIMD boolean bit-packing for q → Arrow projection. +//! +//! q stores boolean vectors as one byte per element (`0` = false, `1` = true, +//! `2` = null on the wire — see [`crate::nulls::Q_NULL_BOOLEAN_WIRE`]). +//! Arrow `BooleanArray` uses a compact bitmap: one bit per element, LSB-first +//! within each byte. +//! +//! [`pack_bool_bytes`] converts a q boolean byte slice into an Arrow-compatible +//! packed bitmap using SIMD comparisons, processing `N` bytes per iteration. + +use std::simd::prelude::*; + +/// Packs a slice of q boolean bytes into an Arrow-compatible LSB-first bitmap. +/// +/// Each source byte is treated as non-zero → `1` bit, zero → `0` bit. 
+/// Null bytes (`2`) are treated as truthy here — callers that need a separate +/// null buffer should pass in a pre-filtered slice or handle nulls separately. +/// +/// Returns `(bitmap_bytes, element_count)` where `bitmap_bytes` is the packed +/// bitmap (length `ceil(src.len() / 8)`) and `element_count == src.len()`. +/// +/// The returned `Vec` is suitable for wrapping directly into an Arrow +/// `arrow_buffer::Buffer` → `arrow_array::types::BooleanBuffer`. +#[inline] +pub fn pack_bool_bytes(src: &[u8]) -> (Vec, usize) { + let len = src.len(); + let out_len = len.div_ceil(8); + let mut out = vec![0u8; out_len]; + + const N: usize = 8; + let zero_v = Simd::::splat(0u8); + // Number of full 8-byte chunks we can process with SIMD. + let n_aligned = (len / N) * N; + + for (i, chunk) in src[..n_aligned].chunks_exact(N).enumerate() { + let v = Simd::::from_slice(chunk); + // Compare each byte to zero: non-zero → true (1-bit), zero → false (0-bit). + let mask: std::simd::Mask = v.simd_ne(zero_v); + // `to_bitmask()` produces a u8 with one bit per lane, LSB = lane 0. + out[i] = mask.to_bitmask() as u8; + } + + // Scalar tail (fewer than N elements remain). + if n_aligned < len { + let mut tail_byte = 0u8; + for (bit, &b) in src[n_aligned..].iter().enumerate() { + if b != 0 { + tail_byte |= 1u8 << bit; + } + } + out[n_aligned / N] = tail_byte; + } + + (out, len) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn pack_empty() { + let (bm, n) = pack_bool_bytes(&[]); + assert_eq!(n, 0); + assert!(bm.is_empty()); + } + + #[test] + fn pack_all_false() { + let src = [0u8; 16]; + let (bm, n) = pack_bool_bytes(&src); + assert_eq!(n, 16); + assert_eq!(bm, [0u8, 0u8]); + } + + #[test] + fn pack_all_true() { + let src = [1u8; 16]; + let (bm, n) = pack_bool_bytes(&src); + assert_eq!(n, 16); + assert_eq!(bm, [0xFF, 0xFF]); + } + + #[test] + fn pack_lsb_first() { + // Only the first element is true → bit 0 of byte 0 should be set. 
+ let mut src = [0u8; 8]; + src[0] = 1; + let (bm, _) = pack_bool_bytes(&src); + assert_eq!(bm[0], 0b00000001); + } + + #[test] + fn pack_last_element_in_first_chunk() { + // Only the 8th element (index 7) is true → bit 7 of byte 0. + let mut src = [0u8; 8]; + src[7] = 1; + let (bm, _) = pack_bool_bytes(&src); + assert_eq!(bm[0], 0b10000000); + } + + #[test] + fn pack_tail_single() { + // 9 elements: first 8 all false, 9th is true → bit 0 of byte 1. + let mut src = [0u8; 9]; + src[8] = 1; + let (bm, n) = pack_bool_bytes(&src); + assert_eq!(n, 9); + assert_eq!(bm.len(), 2); + assert_eq!(bm[0], 0x00); + assert_eq!(bm[1], 0b00000001); + } + + #[test] + fn pack_non_zero_is_true() { + // Any non-zero value should count as true. + let src = [0u8, 2u8, 0u8, 0u8, 0u8, 0u8, 0u8, 0u8]; + let (bm, _) = pack_bool_bytes(&src); + assert_eq!(bm[0], 0b00000010); + } +} diff --git a/crates/qroissant-kernels/src/lib.rs b/crates/qroissant-kernels/src/lib.rs new file mode 100644 index 0000000..f76bec1 --- /dev/null +++ b/crates/qroissant-kernels/src/lib.rs @@ -0,0 +1,25 @@ +#![feature(portable_simd)] +//! SIMD and hot kernels for qroissant. +//! +//! This crate provides two categories of primitives: +//! +//! 1. **Constants** – null sentinels and epoch-offset values used throughout +//! the workspace to interpret q IPC wire bytes. +//! +//! 2. **Scalar transforms** – functions that operate on typed Rust slices. +//! These are correct scalar implementations; future iterations will add +//! `portable_simd` specialisations in this same crate without changing the +//! public API consumed by `qroissant-arrow`. +//! +//! # Architecture rule +//! All nightly-sensitive code (`portable_simd`, intrinsics, etc.) must live +//! in this crate so that the rest of the workspace can remain on stable if +//! needed and so that performance-sensitive code has a single home. 
+ +pub mod boolean; +pub mod nulls; +pub mod temporal; + +pub use boolean::*; +pub use nulls::*; +pub use temporal::*; diff --git a/crates/qroissant-kernels/src/nulls.rs b/crates/qroissant-kernels/src/nulls.rs new file mode 100644 index 0000000..498e20a --- /dev/null +++ b/crates/qroissant-kernels/src/nulls.rs @@ -0,0 +1,371 @@ +//! Null sentinel constants and SIMD-accelerated null-detection helpers for q IPC types. +//! +//! In q's IPC protocol each fixed-width primitive has a dedicated sentinel value +//! that represents a missing (null) element. These constants are consumed by +//! both the Arrow projection layer and any serialisation code that needs to +//! round-trip q nullability semantics. +//! +//! Each `validity_*` function returns `None` when the slice contains no nulls +//! (the fast path: callers can skip building a null buffer entirely) or +//! `Some(Vec)` where `true` means the element is valid. The null check +//! uses `portable_simd` for throughput; the validity-vector build falls back to +//! a scalar loop because nulls are the uncommon case. + +use std::simd::prelude::*; + +/// Null sentinel for q short (i16). +pub const Q_NULL_SHORT: i16 = i16::MIN; + +/// Null sentinel for q int (i32). +pub const Q_NULL_INT: i32 = i32::MIN; + +/// Null sentinel for q long (i64). +pub const Q_NULL_LONG: i64 = i64::MIN; + +/// Null sentinel for q timestamp (i64 nanoseconds since 2000.01.01). +pub const Q_NULL_TIMESTAMP: i64 = i64::MIN; + +/// Null sentinel for q month (i32 months since 2000.01). +pub const Q_NULL_MONTH: i32 = i32::MIN; + +/// Null sentinel for q date (i32 days since 2000.01.01). +pub const Q_NULL_DATE: i32 = i32::MIN; + +/// Null sentinel for q timespan (i64 nanoseconds). +pub const Q_NULL_TIMESPAN: i64 = i64::MIN; + +/// Null sentinel for q minute (i32 minutes). +pub const Q_NULL_MINUTE: i32 = i32::MIN; + +/// Null sentinel for q second (i32 seconds). +pub const Q_NULL_SECOND: i32 = i32::MIN; + +/// Null sentinel for q time (i32 milliseconds). 
+pub const Q_NULL_TIME: i32 = i32::MIN; + +/// Byte value used to encode a null boolean in the raw q IPC wire format. +/// `0` = false, `1` = true, `2` = null. +pub const Q_NULL_BOOLEAN_WIRE: u8 = 2; + +// --------------------------------------------------------------------------- +// Infinity sentinel constants +// --------------------------------------------------------------------------- + +/// Positive infinity sentinel for q short (i16). +pub const Q_INF_SHORT: i16 = i16::MAX; +/// Negative infinity sentinel for q short (i16). +pub const Q_NINF_SHORT: i16 = i16::MIN + 1; + +/// Positive infinity sentinel for q int (i32). +pub const Q_INF_INT: i32 = i32::MAX; +/// Negative infinity sentinel for q int (i32). +pub const Q_NINF_INT: i32 = i32::MIN + 1; + +/// Positive infinity sentinel for q long (i64). +pub const Q_INF_LONG: i64 = i64::MAX; +/// Negative infinity sentinel for q long (i64). +pub const Q_NINF_LONG: i64 = i64::MIN + 1; + +/// Positive infinity sentinel for q real (f32). +pub const Q_INF_REAL: f32 = f32::INFINITY; +/// Negative infinity sentinel for q real (f32). +pub const Q_NINF_REAL: f32 = f32::NEG_INFINITY; + +/// Positive infinity sentinel for q float (f64). +pub const Q_INF_FLOAT: f64 = f64::INFINITY; +/// Negative infinity sentinel for q float (f64). +pub const Q_NINF_FLOAT: f64 = f64::NEG_INFINITY; + +/// Positive infinity sentinel for q timestamp (i64 nanoseconds). +pub const Q_INF_TIMESTAMP: i64 = i64::MAX; +/// Negative infinity sentinel for q timestamp (i64 nanoseconds). +pub const Q_NINF_TIMESTAMP: i64 = i64::MIN + 1; + +/// Positive infinity sentinel for q timespan (i64 nanoseconds). +pub const Q_INF_TIMESPAN: i64 = i64::MAX; +/// Negative infinity sentinel for q timespan (i64 nanoseconds). +pub const Q_NINF_TIMESPAN: i64 = i64::MIN + 1; + +/// Positive infinity sentinel for q date (i32 days). +pub const Q_INF_DATE: i32 = i32::MAX; +/// Negative infinity sentinel for q date (i32 days). 
+pub const Q_NINF_DATE: i32 = i32::MIN + 1;
+
+/// Positive infinity sentinel for q month (i32 months).
+pub const Q_INF_MONTH: i32 = i32::MAX;
+/// Negative infinity sentinel for q month (i32 months).
+pub const Q_NINF_MONTH: i32 = i32::MIN + 1;
+
+/// Positive infinity sentinel for q minute (i32 minutes).
+pub const Q_INF_MINUTE: i32 = i32::MAX;
+/// Negative infinity sentinel for q minute (i32 minutes).
+pub const Q_NINF_MINUTE: i32 = i32::MIN + 1;
+
+/// Positive infinity sentinel for q second (i32 seconds).
+pub const Q_INF_SECOND: i32 = i32::MAX;
+/// Negative infinity sentinel for q second (i32 seconds).
+pub const Q_NINF_SECOND: i32 = i32::MIN + 1;
+
+/// Positive infinity sentinel for q time (i32 milliseconds).
+pub const Q_INF_TIME: i32 = i32::MAX;
+/// Negative infinity sentinel for q time (i32 milliseconds).
+pub const Q_NINF_TIME: i32 = i32::MIN + 1;
+
+// ---------------------------------------------------------------------------
+// SIMD null-detection helpers
+// ---------------------------------------------------------------------------
+
+/// Returns a validity vector for a `&[i16]` slice using [`Q_NULL_SHORT`] as
+/// the sentinel. Returns `None` when no nulls are present.
+///
+/// In the returned vector `true` marks a valid element and `false` marks an
+/// element equal to the sentinel. An empty slice yields `None`.
+#[inline]
+pub fn validity_i16(values: &[i16]) -> Option<Vec<bool>> {
+    const N: usize = 32;
+    let sentinel = Simd::<i16, N>::splat(Q_NULL_SHORT);
+    // Length of the prefix that splits into whole N-lane vectors.
+    let n_aligned = (values.len() / N) * N;
+
+    // Detection pass: vector compare over the whole-vector prefix, scalar
+    // compare over the remaining tail. `any` short-circuits on the first hit.
+    let has_null = values[..n_aligned]
+        .chunks_exact(N)
+        .any(|c| Simd::<i16, N>::from_slice(c).simd_eq(sentinel).any())
+        || values[n_aligned..].iter().any(|&v| v == Q_NULL_SHORT);
+
+    if !has_null {
+        return None;
+    }
+    // Nulls exist: build the per-element validity flags with a scalar pass.
+    Some(values.iter().map(|&v| v != Q_NULL_SHORT).collect())
+}
+
+/// Returns a validity vector for a `&[i32]` slice using the supplied sentinel.
+/// Returns `None` when no nulls are present.
+#[inline]
+pub fn validity_i32(values: &[i32], sentinel: i32) -> Option<Vec<bool>> {
+    const N: usize = 16;
+    let sentinel_v = Simd::<i32, N>::splat(sentinel);
+    // Length of the prefix that splits into whole N-lane vectors.
+    let n_aligned = (values.len() / N) * N;
+
+    // Detection pass: vector compare over the whole-vector prefix, scalar
+    // compare over the remaining tail.
+    let has_null = values[..n_aligned]
+        .chunks_exact(N)
+        .any(|c| Simd::<i32, N>::from_slice(c).simd_eq(sentinel_v).any())
+        || values[n_aligned..].iter().any(|&v| v == sentinel);
+
+    if !has_null {
+        return None;
+    }
+    // `true` = valid, `false` = element equals the sentinel.
+    Some(values.iter().map(|&v| v != sentinel).collect())
+}
+
+/// Returns a validity vector for a `&[i64]` slice using the supplied sentinel.
+/// Returns `None` when no nulls are present.
+///
+/// In the returned vector `true` marks a valid element and `false` marks an
+/// element equal to `sentinel`. An empty slice yields `None`.
+#[inline]
+pub fn validity_i64(values: &[i64], sentinel: i64) -> Option<Vec<bool>> {
+    const N: usize = 8;
+    let sentinel_v = Simd::<i64, N>::splat(sentinel);
+    // Length of the prefix that splits into whole N-lane vectors.
+    let n_aligned = (values.len() / N) * N;
+
+    let has_null = values[..n_aligned]
+        .chunks_exact(N)
+        .any(|c| Simd::<i64, N>::from_slice(c).simd_eq(sentinel_v).any())
+        || values[n_aligned..].iter().any(|&v| v == sentinel);
+
+    if !has_null {
+        return None;
+    }
+    Some(values.iter().map(|&v| v != sentinel).collect())
+}
+
+/// Returns a validity vector for a `&[f32]` slice where `NaN` encodes null.
+/// Returns `None` when no nulls are present.
+///
+/// Infinities are ordinary values here; only `NaN` elements are reported as
+/// invalid (`false`).
+#[inline]
+pub fn validity_f32(values: &[f32]) -> Option<Vec<bool>> {
+    const N: usize = 16;
+    // Length of the prefix that splits into whole N-lane vectors.
+    let n_aligned = (values.len() / N) * N;
+
+    // NaN is the only value not equal to itself.
+    let has_null = values[..n_aligned].chunks_exact(N).any(|c| {
+        let v = Simd::<f32, N>::from_slice(c);
+        v.simd_ne(v).any()
+    }) || values[n_aligned..].iter().any(|v| v.is_nan());
+
+    if !has_null {
+        return None;
+    }
+    Some(values.iter().map(|v| !v.is_nan()).collect())
+}
+
+/// Returns a validity vector for a `&[f64]` slice where `NaN` encodes null.
+/// Returns `None` when no nulls are present.
+#[inline] +pub fn validity_f64(values: &[f64]) -> Option> { + const N: usize = 8; + let n_aligned = (values.len() / N) * N; + + let has_null = values[..n_aligned].chunks_exact(N).any(|c| { + let v = Simd::::from_slice(c); + v.simd_ne(v).any() + }) || values[n_aligned..].iter().any(|v| v.is_nan()); + + if !has_null { + return None; + } + Some(values.iter().map(|v| !v.is_nan()).collect()) +} + +#[cfg(test)] +mod tests { + use super::*; + + // validity_i16 + #[test] + fn i16_no_nulls() { + assert_eq!(validity_i16(&[1, 2, 3, 4, 5]), None); + } + + #[test] + fn i16_with_null() { + assert_eq!( + validity_i16(&[1, Q_NULL_SHORT, 3]), + Some(vec![true, false, true]) + ); + } + + #[test] + fn i16_all_nulls() { + assert_eq!(validity_i16(&[Q_NULL_SHORT; 4]), Some(vec![false; 4])); + } + + #[test] + fn i16_empty() { + assert_eq!(validity_i16(&[]), None); + } + + #[test] + fn i16_single_null() { + assert_eq!(validity_i16(&[Q_NULL_SHORT]), Some(vec![false])); + } + + #[test] + fn i16_null_in_remainder() { + let mut data: Vec = (1..=9).collect(); + data[8] = Q_NULL_SHORT; + let v = validity_i16(&data).unwrap(); + assert!(!v[8]); + assert!(v[0]); + } + + // validity_i32 + #[test] + fn i32_no_nulls() { + assert_eq!(validity_i32(&[1, 2, 3], Q_NULL_INT), None); + } + + #[test] + fn i32_with_null() { + assert_eq!( + validity_i32(&[1, Q_NULL_INT, 3], Q_NULL_INT), + Some(vec![true, false, true]) + ); + } + + #[test] + fn i32_empty() { + assert_eq!(validity_i32(&[], Q_NULL_INT), None); + } + + #[test] + fn i32_all_nulls() { + assert_eq!( + validity_i32(&[Q_NULL_INT; 3], Q_NULL_INT), + Some(vec![false; 3]) + ); + } + + #[test] + fn i32_null_in_remainder() { + let mut data: Vec = (1..=10).collect(); + data[9] = Q_NULL_INT; + assert!(!validity_i32(&data, Q_NULL_INT).unwrap()[9]); + } + + // validity_i64 + #[test] + fn i64_no_nulls() { + assert_eq!(validity_i64(&[1, 2, 3], Q_NULL_LONG), None); + } + + #[test] + fn i64_with_null() { + assert_eq!( + validity_i64(&[1, Q_NULL_LONG, 3], 
Q_NULL_LONG), + Some(vec![true, false, true]) + ); + } + + #[test] + fn i64_empty() { + assert_eq!(validity_i64(&[], Q_NULL_LONG), None); + } + + #[test] + fn i64_timestamp_sentinel() { + assert_eq!( + validity_i64(&[100, Q_NULL_TIMESTAMP, 300], Q_NULL_TIMESTAMP), + Some(vec![true, false, true]) + ); + } + + // validity_f32 + #[test] + fn f32_no_nulls() { + assert_eq!(validity_f32(&[1.0, 2.0, 3.0]), None); + } + + #[test] + fn f32_with_nan() { + assert_eq!( + validity_f32(&[1.0, f32::NAN, 3.0]), + Some(vec![true, false, true]) + ); + } + + #[test] + fn f32_empty() { + assert_eq!(validity_f32(&[]), None); + } + + #[test] + fn f32_infinity_is_not_null() { + assert_eq!(validity_f32(&[f32::INFINITY, f32::NEG_INFINITY, 1.0]), None); + } + + // validity_f64 + #[test] + fn f64_no_nulls() { + assert_eq!(validity_f64(&[1.0, 2.0, 3.0]), None); + } + + #[test] + fn f64_with_nan() { + assert_eq!( + validity_f64(&[1.0, f64::NAN, 3.0]), + Some(vec![true, false, true]) + ); + } + + #[test] + fn f64_empty() { + assert_eq!(validity_f64(&[]), None); + } + + #[test] + fn f64_infinity_is_not_null() { + assert_eq!(validity_f64(&[f64::INFINITY, f64::NEG_INFINITY, 1.0]), None); + } + + #[test] + fn f64_large_aligned_with_null() { + let mut data = vec![1.0; 8]; + data[7] = f64::NAN; + let v = validity_f64(&data).unwrap(); + assert!(v[0]); + assert!(!v[7]); + } +} diff --git a/crates/qroissant-kernels/src/temporal.rs b/crates/qroissant-kernels/src/temporal.rs new file mode 100644 index 0000000..85c90a9 --- /dev/null +++ b/crates/qroissant-kernels/src/temporal.rs @@ -0,0 +1,317 @@ +//! Temporal conversion constants and SIMD transforms for q ↔ Arrow mapping. +//! +//! q encodes temporal values relative to the millennium epoch (2000-01-01) +//! while Arrow uses the Unix epoch (1970-01-01). The helpers here translate +//! between the two without touching Arrow types so that this crate stays free +//! of Arrow dependencies. +//! +//! 
Each transform function uses `portable_simd` for the leading whole-vector
+//! chunks of the slice and falls back to a scalar loop for the remaining tail.
+
+use std::simd::Select;
+use std::simd::prelude::*;
+
+use crate::nulls::Q_NULL_DATE;
+use crate::nulls::Q_NULL_MINUTE;
+use crate::nulls::Q_NULL_TIMESTAMP;
+
+/// Nanoseconds between 1970-01-01 and 2000-01-01.
+pub const TIMESTAMP_OFFSET_NS: i64 = 946_684_800_000_000_000;
+
+/// Days between 1970-01-01 and 2000-01-01.
+pub const DATE_OFFSET_DAYS: i32 = 10_957;
+
+/// Milliseconds in a day (used for `Datetime` float-day conversion).
+pub const MILLIS_PER_DAY: f64 = 86_400_000.0;
+
+/// Translates a slice of q timestamps (nanoseconds since 2000-01-01) into
+/// Arrow `TimestampNanosecond` values (nanoseconds since 1970-01-01) in place.
+///
+/// Null elements (`i64::MIN`) are left unchanged; the Arrow null buffer
+/// produced by [`crate::nulls::validity_i64`] will mask them.
+///
+/// The offset is applied with saturating addition, so values near `i64::MAX`
+/// clamp to `i64::MAX` instead of wrapping.
+#[inline]
+pub fn offset_timestamps(values: &mut [i64]) {
+    const N: usize = 8;
+    let null_v = Simd::<i64, N>::splat(Q_NULL_TIMESTAMP);
+    let offset_v = Simd::<i64, N>::splat(TIMESTAMP_OFFSET_NS);
+    // Length of the prefix that splits into whole N-lane vectors.
+    let n_aligned = (values.len() / N) * N;
+
+    for chunk in values[..n_aligned].chunks_exact_mut(N) {
+        let v = Simd::<i64, N>::from_slice(chunk);
+        // `true` lanes hold real values; `false` lanes hold the null sentinel.
+        let mask = v.simd_ne(null_v);
+        let added = v.saturating_add(offset_v);
+        // Keep the original sentinel in null lanes.
+        let result = mask.select(added, v);
+        chunk.copy_from_slice(&result.to_array());
+    }
+    // Scalar tail: fewer than N elements remain.
+    for v in &mut values[n_aligned..] {
+        if *v != Q_NULL_TIMESTAMP {
+            *v = v.saturating_add(TIMESTAMP_OFFSET_NS);
+        }
+    }
+}
+
+/// Translates a slice of q dates (days since 2000-01-01) into Arrow `Date32`
+/// values (days since 1970-01-01) in place.
+///
+/// Null elements (`i32::MIN`) are left unchanged.
+#[inline]
+pub fn offset_dates(values: &mut [i32]) {
+    const N: usize = 16;
+    let null_v = Simd::<i32, N>::splat(Q_NULL_DATE);
+    let offset_v = Simd::<i32, N>::splat(DATE_OFFSET_DAYS);
+    // Length of the prefix that splits into whole N-lane vectors.
+    let n_aligned = (values.len() / N) * N;
+
+    for chunk in values[..n_aligned].chunks_exact_mut(N) {
+        let v = Simd::<i32, N>::from_slice(chunk);
+        // `true` lanes hold real values; `false` lanes hold the null sentinel.
+        let mask = v.simd_ne(null_v);
+        let added = v.saturating_add(offset_v);
+        // Keep the original sentinel in null lanes.
+        let result = mask.select(added, v);
+        chunk.copy_from_slice(&result.to_array());
+    }
+    // Scalar tail: fewer than N elements remain.
+    for v in &mut values[n_aligned..] {
+        if *v != Q_NULL_DATE {
+            *v = v.saturating_add(DATE_OFFSET_DAYS);
+        }
+    }
+}
+
+/// Translates a slice of q minute values (minutes) into Arrow `Time32Second`
+/// values (seconds) in place.
+///
+/// Null elements (`i32::MIN`) are left unchanged. Non-null values whose
+/// product with 60 does not fit in `i32` saturate to `i32::MAX` / `i32::MIN`,
+/// identically in the SIMD chunks and in the scalar tail.
+#[inline]
+pub fn minutes_to_seconds(values: &mut [i32]) {
+    const N: usize = 16;
+    // Largest / smallest minute counts whose product with 60 still fits in i32.
+    const MAX_MINUTES: i32 = i32::MAX / 60;
+    const MIN_MINUTES: i32 = i32::MIN / 60;
+    let null_v = Simd::<i32, N>::splat(Q_NULL_MINUTE);
+    let sixty_v = Simd::<i32, N>::splat(60_i32);
+    let max_minutes_v = Simd::<i32, N>::splat(MAX_MINUTES);
+    let min_minutes_v = Simd::<i32, N>::splat(MIN_MINUTES);
+    let i32_max_v = Simd::<i32, N>::splat(i32::MAX);
+    let i32_min_v = Simd::<i32, N>::splat(i32::MIN);
+    // Length of the prefix that splits into whole N-lane vectors.
+    let n_aligned = (values.len() / N) * N;
+
+    for chunk in values[..n_aligned].chunks_exact_mut(N) {
+        let v = Simd::<i32, N>::from_slice(chunk);
+        let mask = v.simd_ne(null_v);
+        // Lane-wise equivalent of `i32::saturating_mul(60)`. A plain `v * 60`
+        // wraps on overflow, which would make a large value (for example the
+        // `i32::MAX` infinity sentinel) convert differently depending on
+        // whether it sits in a SIMD chunk or in the scalar tail below.
+        // Clamping first guarantees the multiply itself cannot overflow.
+        let in_range = v.simd_clamp(min_minutes_v, max_minutes_v) * sixty_v;
+        let saturated = v.simd_gt(max_minutes_v).select(i32_max_v, in_range);
+        let saturated = v.simd_lt(min_minutes_v).select(i32_min_v, saturated);
+        // Null sentinels are selected back in unchanged.
+        let result = mask.select(saturated, v);
+        chunk.copy_from_slice(&result.to_array());
+    }
+    // Scalar tail: fewer than N elements remain.
+    for v in &mut values[n_aligned..] {
+        if *v != Q_NULL_MINUTE {
+            *v = v.saturating_mul(60);
+        }
+    }
+}
+
+/// Copies q timestamps (nanoseconds since 2000-01-01) from `src` into `dst`,
+/// applying the Unix-epoch offset in a single SIMD pass.
+///
+/// Avoids the two-pass cost of `to_vec()` + `offset_timestamps()`:
+/// one read from `src`, one write to `dst`, no intermediate allocation.
+/// Null elements (`i64::MIN`) are copied unchanged.
+///
+/// `src` and `dst` must have the same length.
+#[inline]
+pub fn copy_and_offset_timestamps(src: &[i64], dst: &mut [i64]) {
+    // Enforced in release builds too: with mismatched lengths the zips below
+    // would otherwise copy only part of the data without any signal.
+    assert_eq!(
+        src.len(),
+        dst.len(),
+        "src and dst must have the same length"
+    );
+    const N: usize = 8;
+    let null_v = Simd::<i64, N>::splat(Q_NULL_TIMESTAMP);
+    let offset_v = Simd::<i64, N>::splat(TIMESTAMP_OFFSET_NS);
+    // Length of the prefix that splits into whole N-lane vectors.
+    let n_aligned = (src.len() / N) * N;
+
+    for (s, d) in src[..n_aligned]
+        .chunks_exact(N)
+        .zip(dst[..n_aligned].chunks_exact_mut(N))
+    {
+        let v = Simd::<i64, N>::from_slice(s);
+        // `true` lanes hold real values; `false` lanes hold the null sentinel.
+        let mask = v.simd_ne(null_v);
+        let result = mask.select(v.saturating_add(offset_v), v);
+        d.copy_from_slice(&result.to_array());
+    }
+    // Scalar tail: fewer than N elements remain.
+    for (s, d) in src[n_aligned..].iter().zip(dst[n_aligned..].iter_mut()) {
+        *d = if *s != Q_NULL_TIMESTAMP {
+            s.saturating_add(TIMESTAMP_OFFSET_NS)
+        } else {
+            *s
+        };
+    }
+}
+
+/// Copies q dates (days since 2000-01-01) from `src` into `dst`,
+/// applying the Unix-epoch offset in a single SIMD pass.
+///
+/// Null elements (`i32::MIN`) are copied unchanged.
+///
+/// `src` and `dst` must have the same length.
+///
+/// # Panics
+/// Panics if `src.len() != dst.len()`.
+#[inline]
+pub fn copy_and_offset_dates(src: &[i32], dst: &mut [i32]) {
+    assert_eq!(
+        src.len(),
+        dst.len(),
+        "src and dst must have the same length"
+    );
+    const N: usize = 16;
+    let null_v = Simd::<i32, N>::splat(Q_NULL_DATE);
+    let offset_v = Simd::<i32, N>::splat(DATE_OFFSET_DAYS);
+    // Length of the prefix that splits into whole N-lane vectors.
+    let n_aligned = (src.len() / N) * N;
+
+    for (s, d) in src[..n_aligned]
+        .chunks_exact(N)
+        .zip(dst[..n_aligned].chunks_exact_mut(N))
+    {
+        let v = Simd::<i32, N>::from_slice(s);
+        let mask = v.simd_ne(null_v);
+        let result = mask.select(v.saturating_add(offset_v), v);
+        d.copy_from_slice(&result.to_array());
+    }
+    // Scalar tail: fewer than N elements remain.
+    for (s, d) in src[n_aligned..].iter().zip(dst[n_aligned..].iter_mut()) {
+        *d = if *s != Q_NULL_DATE {
+            s.saturating_add(DATE_OFFSET_DAYS)
+        } else {
+            *s
+        };
+    }
+}
+
+/// Copies q minute values from `src` into `dst`, converting minutes → seconds
+/// in a single SIMD pass.
+///
+/// `src` and `dst` must have the same length.
+#[inline]
+pub fn copy_and_minutes_to_seconds(src: &[i32], dst: &mut [i32]) {
+    // Enforced in release builds too: with mismatched lengths the zips below
+    // would otherwise convert only part of the data without any signal.
+    assert_eq!(
+        src.len(),
+        dst.len(),
+        "src and dst must have the same length"
+    );
+    const N: usize = 16;
+    // Largest / smallest minute counts whose product with 60 still fits in i32.
+    const MAX_MINUTES: i32 = i32::MAX / 60;
+    const MIN_MINUTES: i32 = i32::MIN / 60;
+    let null_v = Simd::<i32, N>::splat(Q_NULL_MINUTE);
+    let sixty_v = Simd::<i32, N>::splat(60_i32);
+    let max_minutes_v = Simd::<i32, N>::splat(MAX_MINUTES);
+    let min_minutes_v = Simd::<i32, N>::splat(MIN_MINUTES);
+    let i32_max_v = Simd::<i32, N>::splat(i32::MAX);
+    let i32_min_v = Simd::<i32, N>::splat(i32::MIN);
+    // Length of the prefix that splits into whole N-lane vectors.
+    let n_aligned = (src.len() / N) * N;
+
+    for (s, d) in src[..n_aligned]
+        .chunks_exact(N)
+        .zip(dst[..n_aligned].chunks_exact_mut(N))
+    {
+        let v = Simd::<i32, N>::from_slice(s);
+        let mask = v.simd_ne(null_v);
+        // Lane-wise equivalent of `i32::saturating_mul(60)`. A plain `v * 60`
+        // wraps on overflow, which would make a large value convert
+        // differently in a SIMD chunk than in the scalar tail below.
+        // Clamping first guarantees the multiply itself cannot overflow.
+        let in_range = v.simd_clamp(min_minutes_v, max_minutes_v) * sixty_v;
+        let saturated = v.simd_gt(max_minutes_v).select(i32_max_v, in_range);
+        let saturated = v.simd_lt(min_minutes_v).select(i32_min_v, saturated);
+        // Null sentinels are copied through unchanged.
+        let result = mask.select(saturated, v);
+        d.copy_from_slice(&result.to_array());
+    }
+    // Scalar tail: fewer than N elements remain.
+    for (s, d) in src[n_aligned..].iter().zip(dst[n_aligned..].iter_mut()) {
+        *d = if *s != Q_NULL_MINUTE {
+            s.saturating_mul(60)
+        } else {
+            *s
+        };
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    // -----------------------------------------------------------------------
+    // offset_timestamps
+    // -----------------------------------------------------------------------
+
+    #[test]
+    fn offset_timestamps_basic() {
+        // q timestamp 1 ns since 2000 -> Unix epoch ns
+        let mut values = vec![1i64];
+        offset_timestamps(&mut values);
+        assert_eq!(values[0], TIMESTAMP_OFFSET_NS + 1);
+    }
+
+    #[test]
+    fn offset_timestamps_zero() {
+        let mut values = vec![0i64];
+        offset_timestamps(&mut values);
+        assert_eq!(values[0], TIMESTAMP_OFFSET_NS);
+    }
+
+    #[test]
+    fn offset_timestamps_preserves_null() {
+        let mut values = vec![Q_NULL_TIMESTAMP];
+        offset_timestamps(&mut values);
+        assert_eq!(values[0], Q_NULL_TIMESTAMP);
+    }
+
+    #[test]
+    fn offset_timestamps_mixed() {
+        let mut values = vec![0, Q_NULL_TIMESTAMP, 1000, Q_NULL_TIMESTAMP, 2000];
+        offset_timestamps(&mut values);
+        assert_eq!(values[0], TIMESTAMP_OFFSET_NS);
+        assert_eq!(values[1], Q_NULL_TIMESTAMP);
+        assert_eq!(values[2], TIMESTAMP_OFFSET_NS + 1000);
+        assert_eq!(values[3], Q_NULL_TIMESTAMP);
+        assert_eq!(values[4], TIMESTAMP_OFFSET_NS + 2000);
+    }
+
+    #[test]
+    fn offset_timestamps_empty() {
+        let mut values: Vec<i64> = vec![];
+        offset_timestamps(&mut values);
+        assert!(values.is_empty());
+    }
+
+    // -----------------------------------------------------------------------
+    // offset_dates
+    // -----------------------------------------------------------------------
+
+    #[test]
+    fn offset_dates_basic() {
+        let mut values = vec![0i32]; // 2000-01-01 -> days since Unix epoch
+        offset_dates(&mut values);
+        assert_eq!(values[0], DATE_OFFSET_DAYS);
+    }
+
+    #[test]
+    fn offset_dates_preserves_null() {
+        let mut values = vec![Q_NULL_DATE];
+        offset_dates(&mut values);
+        assert_eq!(values[0], Q_NULL_DATE);
+    }
+
+    #[test]
+    fn offset_dates_mixed() {
+        let mut values = vec![0, Q_NULL_DATE, 1, Q_NULL_DATE];
+        offset_dates(&mut values);
+        assert_eq!(values[0], DATE_OFFSET_DAYS);
+        assert_eq!(values[1], Q_NULL_DATE);
+        assert_eq!(values[2], DATE_OFFSET_DAYS + 1);
+        assert_eq!(values[3], Q_NULL_DATE);
+    }
+
+    #[test]
+    fn offset_dates_empty() {
+        let mut values: Vec<i32> = vec![];
+        offset_dates(&mut values);
+        assert!(values.is_empty());
+    }
+
+    // -----------------------------------------------------------------------
+    // minutes_to_seconds
+    // -----------------------------------------------------------------------
+
+    #[test]
+    fn minutes_to_seconds_basic() {
+        let mut values = vec![10i32]; // 10 minutes -> 600 seconds
+        minutes_to_seconds(&mut values);
+        assert_eq!(values[0], 600);
+    }
+
+    #[test]
+    fn minutes_to_seconds_preserves_null() {
+        let mut values = vec![Q_NULL_MINUTE];
+        minutes_to_seconds(&mut values);
+        assert_eq!(values[0], Q_NULL_MINUTE);
+    }
+
+    #[test]
+    fn minutes_to_seconds_mixed() {
+        let mut values = vec![1, Q_NULL_MINUTE, 60];
+        minutes_to_seconds(&mut values);
+        assert_eq!(values[0], 60);
+        assert_eq!(values[1], Q_NULL_MINUTE);
+        assert_eq!(values[2], 3600);
+    }
+
+    #[test]
+    fn minutes_to_seconds_empty() {
+        let mut values: Vec<i32> = vec![];
+        minutes_to_seconds(&mut values);
+        assert!(values.is_empty());
+    }
+
+    // -----------------------------------------------------------------------
+    // copy_and_minutes_to_seconds
+    // -----------------------------------------------------------------------
+
+    #[test]
+    fn copy_and_minutes_to_seconds_saturates_in_every_lane() {
+        // 17 elements: 16 go through the SIMD chunk, 1 through the scalar tail.
+        let src = vec![i32::MAX; 17];
+        let mut dst = vec![0i32; 17];
+        copy_and_minutes_to_seconds(&src, &mut dst);
+        assert!(dst.iter().all(|&v| v == i32::MAX));
+    }
+
+    #[test]
+    fn copy_and_minutes_to_seconds_preserves_null_in_simd_chunk() {
+        let mut src = vec![2i32; 16];
+        src[3] = Q_NULL_MINUTE;
+        let mut dst = vec![0i32; 16];
+        copy_and_minutes_to_seconds(&src, &mut dst);
+        assert_eq!(dst[3], Q_NULL_MINUTE);
+        assert_eq!(dst[0], 120);
+    }
+}
diff --git a/crates/qroissant-python/Cargo.toml b/crates/qroissant-python/Cargo.toml
new file mode 100644
index 0000000..5ea132e --- /dev/null +++ b/crates/qroissant-python/Cargo.toml @@ -0,0 +1,27 @@ +[package] +name = "qroissant-python" +version.workspace = true +edition.workspace = true +license.workspace = true +publish = false + +[lib] +name = "_native" +crate-type = ["cdylib", "rlib"] +path = "src/lib.rs" + +[dependencies] +bb8 = "0.9.0" +bytes = "1.11.1" +chrono = "0.4.44" +pyo3 = { workspace = true, features = ["extension-module"] } +pyo3-arrow = { version = "0.17.0", default-features = false } +pyo3-async-runtimes = { version = "0.28.0", features = ["tokio-runtime"] } +qroissant-arrow = { path = "../qroissant-arrow" } +qroissant-core = { path = "../qroissant-core" } +qroissant-kernels = { path = "../qroissant-kernels" } +qroissant-transport = { path = "../qroissant-transport" } +r2d2 = "0.8.10" +tabled = "0.17.0" +thiserror = "2.0.18" +tokio = { version = "1.48.0", features = ["io-util", "net", "rt-multi-thread", "sync", "time"] } diff --git a/crates/qroissant-python/src/client.rs b/crates/qroissant-python/src/client.rs new file mode 100644 index 0000000..91fea0b --- /dev/null +++ b/crates/qroissant-python/src/client.rs @@ -0,0 +1,1597 @@ +use std::io::Read; +use std::io::Write; +use std::pin::Pin; +use std::sync::Arc; +use std::sync::Mutex; +use std::task::Context; +use std::task::Poll; +use std::thread; +use std::time::Duration; + +use bb8::ManageConnection as AsyncManageConnection; +use pyo3::PyClass; +use pyo3::prelude::*; +use pyo3::types::PyAny; +use pyo3::types::PyModule; +use pyo3_async_runtimes::tokio::future_into_py; +use qroissant_core::Value; +use qroissant_transport::AsyncPooledTransport; +use qroissant_transport::AsyncTransport; +use qroissant_transport::QIPC_HEADER_LEN; +use qroissant_transport::SyncPooledTransport; +use qroissant_transport::SyncTransport; +use qroissant_transport::begin_streaming_frame_over; +use qroissant_transport::begin_streaming_frame_over_async; +use qroissant_transport::connect_tcp_transport; +use 
qroissant_transport::connect_tcp_transport_async; +#[cfg(unix)] +use qroissant_transport::connect_unix_transport; +#[cfg(unix)] +use qroissant_transport::connect_unix_transport_async; +use qroissant_transport::encode_sync_query; +use qroissant_transport::extract_q_error; +use qroissant_transport::request_frame_streaming_over; +use qroissant_transport::request_value_pipelined_over_async; +use qroissant_transport::validate_response_frame; +use qroissant_transport::validate_response_header_bytes; +use r2d2::ManageConnection; +use r2d2::PooledConnection; +use tokio::io::AsyncRead; +use tokio::io::AsyncWrite; +use tokio::io::ReadBuf; +use tokio::sync::Mutex as AsyncMutex; +use tokio::time::sleep; + +use crate::errors::PythonError; +use crate::errors::to_py_err; +use crate::raw_response::AsyncStreamingLease; +use crate::raw_response::BlockingAsyncBridge; +use crate::raw_response::RawResponse; +use crate::raw_response::SyncRawLease; +use crate::serde::decode_core_value; +use crate::types::DecodeOptions; +use crate::types::Endpoint; +use crate::types::MessageHeader; +use crate::types::PoolMetrics; +use crate::types::PoolOptions; +use crate::values::core_value_to_python_with_opts; + +fn async_self(py: Python<'_>, constructor: F) -> PyResult> +where + T: PyClass + Send + Sync + 'static, + F: FnOnce(Python<'_>) -> PyResult> + Send + 'static, +{ + future_into_py(py, async move { + Python::attach(|py| constructor(py).map(|value| value.into_any())) + }) +} + +fn connect_sync_transport(endpoint: &Endpoint) -> Result { + endpoint.validate().map_err(PythonError::Transport)?; + match endpoint.scheme_value() { + "tcp" => connect_tcp_transport( + endpoint.host_value().unwrap_or_default(), + endpoint.port_value().unwrap_or_default(), + endpoint.username_deref(), + endpoint.password_deref(), + endpoint.timeout_ms_value(), + ) + .map_err(crate::errors::map_transport_error), + #[cfg(unix)] + "unix" => connect_unix_transport( + endpoint.path_value().unwrap_or_default(), + 
endpoint.username_deref(), + endpoint.password_deref(), + endpoint.timeout_ms_value(), + ) + .map_err(crate::errors::map_transport_error), + other => Err(PythonError::Transport(format!( + "unsupported endpoint scheme {other:?}" + ))), + } +} + +async fn connect_async_transport(endpoint: &Endpoint) -> Result { + endpoint.validate().map_err(PythonError::Transport)?; + match endpoint.scheme_value() { + "tcp" => connect_tcp_transport_async( + endpoint.host_value().unwrap_or_default(), + endpoint.port_value().unwrap_or_default(), + endpoint.username_deref(), + endpoint.password_deref(), + endpoint.timeout_ms_value(), + ) + .await + .map_err(crate::errors::map_transport_error), + #[cfg(unix)] + "unix" => connect_unix_transport_async( + endpoint.path_value().unwrap_or_default(), + endpoint.username_deref(), + endpoint.password_deref(), + endpoint.timeout_ms_value(), + ) + .await + .map_err(crate::errors::map_transport_error), + other => Err(PythonError::Transport(format!( + "unsupported endpoint scheme {other:?}" + ))), + } +} + +fn validate_success_response(frame: &[u8]) -> Result<(), PythonError> { + validate_response_frame(frame).map_err(crate::errors::map_transport_error)?; + if let Some(message) = extract_q_error(frame).map_err(crate::errors::map_transport_error)? 
{ + return Err(PythonError::QRuntime(message)); + } + Ok(()) +} + +fn response_header(header_bytes: [u8; QIPC_HEADER_LEN]) -> Result { + let header = + validate_response_header_bytes(header_bytes).map_err(crate::errors::map_transport_error)?; + Ok(MessageHeader::from(header)) +} + +fn retryable_pool_error(error: &PythonError) -> bool { + matches!(error, PythonError::Transport(_) | PythonError::Pool(_)) +} + +type SyncRawQuery = ( + MessageHeader, + [u8; QIPC_HEADER_LEN], + usize, + Box, +); + +type AsyncConnectionRawQuery = ( + MessageHeader, + [u8; QIPC_HEADER_LEN], + usize, + AsyncConnectionLease, +); + +type AsyncPoolRawQuery = ( + MessageHeader, + [u8; QIPC_HEADER_LEN], + usize, + AsyncPooledConnectionLease, +); + +#[derive(Debug, thiserror::Error)] +enum PoolBackendError { + #[error("{0}")] + Message(String), +} + +impl From for PoolBackendError { + fn from(value: PythonError) -> Self { + Self::Message(value.to_string()) + } +} + +fn sync_backoff(delay_ms: u64) { + if delay_ms > 0 { + thread::sleep(Duration::from_millis(delay_ms)); + } +} + +async fn async_backoff(delay_ms: u64) { + if delay_ms > 0 { + sleep(Duration::from_millis(delay_ms)).await; + } +} + +struct ConnectionSlotState { + transport: Option, + closed: bool, +} + +struct ConnectionLease { + state: Arc>, + transport: Option, + reusable: bool, +} + +impl ConnectionLease { + fn new(state: Arc>, transport: SyncTransport) -> Self { + Self { + state, + transport: Some(transport), + reusable: false, + } + } +} + +impl Read for ConnectionLease { + fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + self.transport + .as_mut() + .expect("active connection lease must hold a transport") + .read(buf) + } +} + +impl Write for ConnectionLease { + fn write(&mut self, buf: &[u8]) -> std::io::Result { + self.transport + .as_mut() + .expect("active connection lease must hold a transport") + .write(buf) + } + + fn flush(&mut self) -> std::io::Result<()> { + self.transport + .as_mut() + .expect("active 
connection lease must hold a transport") + .flush() + } +} + +impl SyncRawLease for ConnectionLease { + fn mark_reusable(&mut self) { + self.reusable = true; + } + + fn abandon(&mut self) { + self.reusable = false; + } +} + +impl Drop for ConnectionLease { + fn drop(&mut self) { + let Some(mut transport) = self.transport.take() else { + return; + }; + + let mut state = match self.state.lock() { + Ok(state) => state, + Err(_) => { + let _ = transport.shutdown(); + return; + } + }; + + if self.reusable && !state.closed && state.transport.is_none() { + state.transport = Some(transport); + return; + } + + let _ = transport.shutdown(); + } +} + +struct ConnectionCore { + state: Arc>, +} + +impl ConnectionCore { + fn connect(endpoint: &Endpoint) -> Result { + Ok(Self { + state: Arc::new(Mutex::new(ConnectionSlotState { + transport: Some(connect_sync_transport(endpoint)?), + closed: false, + })), + }) + } + + fn checkout(&self) -> Result { + let mut state = self + .state + .lock() + .map_err(|_| PythonError::Operation("connection lock is poisoned".to_string()))?; + if state.closed { + return Err(PythonError::Operation("connection is closed".to_string())); + } + let transport = state.transport.take().ok_or_else(|| { + PythonError::Operation("connection is busy with an active raw response".to_string()) + })?; + Ok(ConnectionLease::new(self.state.clone(), transport)) + } + + /// Sends a synchronous query and reads the full response frame. 
+    fn query_frame(&self, expr: &str) -> Result, PythonError> {
+        // NOTE(review): the return type above reads `Result, PythonError>` in
+        // this copy of the patch; a generic argument list appears to have been
+        // lost in extraction. Confirm against upstream before applying.
+        let payload = encode_sync_query(expr).map_err(crate::errors::map_transport_error)?;
+        // Takes the transport out of the shared slot for the whole exchange.
+        let mut lease = self.checkout()?;
+        match request_frame_streaming_over(&mut lease, &payload) {
+            Ok(frame) => {
+                // If validation fails, `?` returns before `mark_reusable`, so
+                // the lease is dropped non-reusable and its `Drop` shuts the
+                // transport down instead of returning it to the slot.
+                validate_response_frame(&frame).map_err(crate::errors::map_transport_error)?;
+                lease.mark_reusable();
+                Ok(frame)
+            }
+            Err(error) => {
+                lease.abandon();
+                Err(crate::errors::map_transport_error(error))
+            }
+        }
+    }
+
+    /// Sends a synchronous query and returns a streaming raw-response handle.
+    fn begin_raw_query(&self, expr: &str) -> Result {
+        // NOTE(review): the return type above reads bare `Result` in this copy
+        // of the patch; its generic arguments appear to have been lost in
+        // extraction. Confirm against upstream before applying.
+        let payload = encode_sync_query(expr).map_err(crate::errors::map_transport_error)?;
+        let mut lease = self.checkout()?;
+        match begin_streaming_frame_over(&mut lease, &payload) {
+            Ok((header_bytes, remaining_body)) => {
+                let header = match response_header(header_bytes) {
+                    Ok(header) => header,
+                    Err(error) => {
+                        lease.abandon();
+                        return Err(error);
+                    }
+                };
+                // The lease is handed to the caller inside the returned tuple;
+                // it is not marked reusable here.
+                Ok((header, header_bytes, remaining_body, Box::new(lease)))
+            }
+            Err(error) => {
+                lease.abandon();
+                Err(crate::errors::map_transport_error(error))
+            }
+        }
+    }
+
+    // Marks the slot closed and shuts down the idle transport, if the slot
+    // currently holds one. A transport that is checked out at this point is
+    // shut down by its lease's `Drop`, which checks the `closed` flag.
+    fn close(&self) -> Result<(), PythonError> {
+        // The lock guard is scoped to this block, so the shutdown below runs
+        // without holding the mutex.
+        let transport = {
+            let mut state = self
+                .state
+                .lock()
+                .map_err(|_| PythonError::Operation("connection lock is poisoned".to_string()))?;
+            state.closed = true;
+            state.transport.take()
+        };
+        if let Some(mut transport) = transport {
+            transport
+                .shutdown()
+                .map_err(|error| PythonError::Transport(error.to_string()))?;
+        }
+        Ok(())
+    }
+}
+
+/// Connection slot state machine for async connections.
+/// +/// ```text +/// Disconnected ─── connect ──→ Busy ─── success ──→ Ready +/// ↑ │ │ +/// └── connect fail ────────┘ │ +/// └── mark_broken ─────────────────────────────┘ +/// │ +/// Closed ←── close ───┘ +/// ``` +enum AsyncConnectionSlotState { + Disconnected, + Ready(AsyncTransport), + Busy, + Closed, +} + +struct AsyncConnectionLease { + state: Arc>, + transport: Option, + reusable: bool, +} + +impl AsyncConnectionLease { + fn new(state: Arc>, transport: AsyncTransport) -> Self { + Self { + state, + transport: Some(transport), + reusable: false, + } + } + + fn restore_transport( + state: &mut AsyncConnectionSlotState, + transport: AsyncTransport, + reusable: bool, + ) { + if reusable { + if matches!(state, AsyncConnectionSlotState::Busy) { + *state = AsyncConnectionSlotState::Ready(transport); + return; + } + } else if matches!(state, AsyncConnectionSlotState::Busy) { + *state = AsyncConnectionSlotState::Disconnected; + } + drop(transport); + } +} + +impl AsyncRead for AsyncConnectionLease { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + Pin::new( + self.transport + .as_mut() + .expect("active async connection lease must hold a transport"), + ) + .poll_read(cx, buf) + } +} + +impl AsyncWrite for AsyncConnectionLease { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new( + self.transport + .as_mut() + .expect("active async connection lease must hold a transport"), + ) + .poll_write(cx, buf) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new( + self.transport + .as_mut() + .expect("active async connection lease must hold a transport"), + ) + .poll_flush(cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new( + self.transport + .as_mut() + .expect("active async connection lease must hold a transport"), + ) + .poll_shutdown(cx) + } +} + +impl 
AsyncStreamingLease for AsyncConnectionLease { + fn mark_reusable(&mut self) { + self.reusable = true; + } + + fn abandon(&mut self) { + self.reusable = false; + } +} + +impl Drop for AsyncConnectionLease { + fn drop(&mut self) { + let Some(transport) = self.transport.take() else { + return; + }; + + if let Ok(mut state) = self.state.try_lock() { + Self::restore_transport(&mut state, transport, self.reusable); + return; + } + + if let Ok(handle) = tokio::runtime::Handle::try_current() { + let state = self.state.clone(); + let reusable = self.reusable; + handle.spawn(async move { + let mut state = state.lock().await; + Self::restore_transport(&mut state, transport, reusable); + }); + return; + } + + let mut state = self.state.blocking_lock(); + Self::restore_transport(&mut state, transport, self.reusable); + } +} + +struct AsyncConnectionCore { + endpoint: Endpoint, + state: Arc>, +} + +impl AsyncConnectionCore { + fn new(endpoint: Endpoint) -> Result { + endpoint.validate().map_err(PythonError::Transport)?; + Ok(Self { + endpoint, + state: Arc::new(AsyncMutex::new(AsyncConnectionSlotState::Disconnected)), + }) + } + + async fn checkout(&self) -> Result { + let should_connect = { + let mut state = self.state.lock().await; + match std::mem::replace(&mut *state, AsyncConnectionSlotState::Busy) { + AsyncConnectionSlotState::Ready(transport) => { + return Ok(AsyncConnectionLease::new(self.state.clone(), transport)); + } + AsyncConnectionSlotState::Disconnected => true, + AsyncConnectionSlotState::Busy => { + *state = AsyncConnectionSlotState::Busy; + return Err(PythonError::Operation( + "connection is busy with an active raw response".to_string(), + )); + } + AsyncConnectionSlotState::Closed => { + *state = AsyncConnectionSlotState::Closed; + return Err(PythonError::Operation("connection is closed".to_string())); + } + } + }; + + debug_assert!(should_connect); + match connect_async_transport(&self.endpoint).await { + Ok(transport) => 
Ok(AsyncConnectionLease::new(self.state.clone(), transport)), + Err(error) => { + let mut state = self.state.lock().await; + if !matches!(&*state, AsyncConnectionSlotState::Closed) { + *state = AsyncConnectionSlotState::Disconnected; + } + Err(error) + } + } + } + + async fn query_value(&self, expr: &str) -> Result { + let payload = encode_sync_query(expr).map_err(crate::errors::map_transport_error)?; + let mut lease = self.checkout().await?; + match qroissant_transport::request_value_pipelined_over_async(&mut lease, &payload).await { + Ok(value) => { + lease.mark_reusable(); + Ok(value) + } + Err(error) => { + lease.abandon(); + Err(crate::errors::map_transport_error(error)) + } + } + } + + /// Sends an async query and returns a streaming raw-response handle. + async fn begin_raw_query(&self, expr: &str) -> Result { + let payload = encode_sync_query(expr).map_err(crate::errors::map_transport_error)?; + let mut lease = self.checkout().await?; + match begin_streaming_frame_over_async(&mut lease, &payload).await { + Ok((header_bytes, remaining_body)) => { + let header = match response_header(header_bytes) { + Ok(header) => header, + Err(error) => { + lease.abandon(); + return Err(error); + } + }; + Ok((header, header_bytes, remaining_body, lease)) + } + Err(error) => { + lease.abandon(); + Err(crate::errors::map_transport_error(error)) + } + } + } + + async fn close(&self) -> Result<(), PythonError> { + let transport = { + let mut state = self.state.lock().await; + match std::mem::replace(&mut *state, AsyncConnectionSlotState::Closed) { + AsyncConnectionSlotState::Ready(transport) => Some(transport), + AsyncConnectionSlotState::Disconnected + | AsyncConnectionSlotState::Busy + | AsyncConnectionSlotState::Closed => None, + } + }; + if let Some(mut transport) = transport { + transport + .shutdown() + .await + .map_err(|error| PythonError::Transport(error.to_string()))?; + } + Ok(()) + } +} + +fn sync_request(conn: &mut SyncPooledTransport, payload: &[u8]) -> Result, 
PythonError> { + match request_frame_streaming_over(conn, payload) { + Ok(frame) => Ok(frame), + Err(error) => { + conn.mark_broken(); + Err(crate::errors::map_transport_error(error)) + } + } +} + +fn sync_send_query(conn: &mut SyncPooledTransport, query: &str) -> Result, PythonError> { + let payload = encode_sync_query(query).map_err(crate::errors::map_transport_error)?; + sync_request(conn, &payload) +} + +fn sync_validate_connection( + conn: &mut SyncPooledTransport, + healthcheck_query: Option<&str>, +) -> Result<(), PoolBackendError> { + if conn.is_broken() { + return Err(PoolBackendError::Message( + "pooled connection is broken".to_string(), + )); + } + if let Some(query) = healthcheck_query { + let response = sync_send_query(conn, query)?; + validate_success_response(&response)?; + } + Ok(()) +} + +#[derive(Clone)] +struct SyncTransportManager { + endpoint: Endpoint, + healthcheck_query: Option, +} + +impl ManageConnection for SyncTransportManager { + type Connection = SyncPooledTransport; + type Error = PoolBackendError; + + fn connect(&self) -> Result { + connect_sync_transport(&self.endpoint) + .map(SyncPooledTransport::new) + .map_err(Into::into) + } + + fn is_valid(&self, conn: &mut Self::Connection) -> Result<(), Self::Error> { + sync_validate_connection(conn, self.healthcheck_query.as_deref()) + } + + fn has_broken(&self, conn: &mut Self::Connection) -> bool { + conn.is_broken() + } +} + +struct SyncPooledConnectionLease { + connection: Option>, + reusable: bool, +} + +impl SyncPooledConnectionLease { + fn new(connection: PooledConnection) -> Self { + Self { + connection: Some(connection), + reusable: false, + } + } + + fn mark_inner_broken(&mut self) { + if let Some(connection) = self.connection.as_mut() { + connection.mark_broken(); + } + } +} + +impl Read for SyncPooledConnectionLease { + fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + self.connection + .as_mut() + .expect("active pooled lease must hold a checked-out connection") + 
.read(buf) + } +} + +impl SyncRawLease for SyncPooledConnectionLease { + fn mark_reusable(&mut self) { + self.reusable = true; + } + + fn abandon(&mut self) { + self.reusable = false; + self.mark_inner_broken(); + } +} + +impl Drop for SyncPooledConnectionLease { + fn drop(&mut self) { + if !self.reusable { + self.mark_inner_broken(); + } + } +} + +struct SyncPoolState { + pool: Option>, + closed: bool, +} + +struct SyncPoolCore { + manager: SyncTransportManager, + options: PoolOptions, + state: Mutex, +} + +impl SyncPoolCore { + fn new(endpoint: Endpoint, options: PoolOptions) -> Result { + endpoint.validate().map_err(PythonError::Pool)?; + Ok(Self { + manager: SyncTransportManager { + endpoint, + healthcheck_query: options.healthcheck_query_value().map(str::to_string), + }, + options, + state: Mutex::new(SyncPoolState { + pool: None, + closed: false, + }), + }) + } + + fn build_pool(&self) -> Result, PoolBackendError> { + let mut builder = r2d2::Pool::builder() + .max_size(self.options.max_size_value()) + .test_on_check_out(self.options.test_on_checkout_value()) + .connection_timeout(Duration::from_millis( + self.options.checkout_timeout_ms_value(), + )); + builder = builder.min_idle(self.options.min_idle_value()); + builder = builder.idle_timeout( + self.options + .idle_timeout_ms_value() + .map(Duration::from_millis), + ); + builder = builder.max_lifetime( + self.options + .max_lifetime_ms_value() + .map(Duration::from_millis), + ); + builder + .build(self.manager.clone()) + .map_err(|error| PoolBackendError::Message(error.to_string())) + } + + fn pool(&self) -> Result, PythonError> { + let mut state = self + .state + .lock() + .map_err(|_| PythonError::Pool("pool lock is poisoned".to_string()))?; + if state.closed { + return Err(PythonError::PoolClosed); + } + if let Some(pool) = &state.pool { + return Ok(pool.clone()); + } + let pool = self + .build_pool() + .map_err(|error| PythonError::Pool(error.to_string()))?; + state.pool = Some(pool.clone()); + Ok(pool) 
+ } + + fn close(&self) -> Result<(), PythonError> { + let mut state = self + .state + .lock() + .map_err(|_| PythonError::Pool("pool lock is poisoned".to_string()))?; + state.closed = true; + state.pool.take(); + Ok(()) + } + + fn metrics(&self) -> Result { + let state = self + .state + .lock() + .map_err(|_| PythonError::Pool("pool lock is poisoned".to_string()))?; + let (connections, idle_connections, initialized) = state + .pool + .as_ref() + .map(|pool| { + let pool_state = pool.state(); + (pool_state.connections, pool_state.idle_connections, true) + }) + .unwrap_or((0, 0, false)); + Ok(PoolMetrics::new_native( + connections, + idle_connections, + self.options.max_size_value(), + self.options.min_idle_value(), + initialized, + state.closed, + )) + } + + fn prewarm(&self) -> Result { + let target = self.options.min_idle_value().unwrap_or(1); + if target == 0 { + return self.metrics(); + } + let pool = self.pool()?; + let mut connections = Vec::with_capacity(target as usize); + for _ in 0..target { + connections.push( + pool.get() + .map_err(|error| PythonError::Pool(error.to_string()))?, + ); + } + drop(connections); + self.metrics() + } + + fn try_query_frame_once(&self, expr: &str) -> Result, PythonError> { + let payload = encode_sync_query(expr).map_err(crate::errors::map_transport_error)?; + let pool = self.pool()?; + let mut connection = pool + .get() + .map_err(|error| PythonError::Pool(error.to_string()))?; + let frame = request_frame_streaming_over(&mut *connection, &payload).map_err(|error| { + connection.mark_broken(); + crate::errors::map_transport_error(error) + })?; + validate_response_frame(&frame).map_err(crate::errors::map_transport_error)?; + Ok(frame) + } + + fn query_frame(&self, expr: &str) -> Result, PythonError> { + for attempt in 0..=self.options.retry_attempts_value() { + match self.try_query_frame_once(expr) { + Ok(frame) => return Ok(frame), + Err(error) => { + if attempt == self.options.retry_attempts_value() + || 
!retryable_pool_error(&error) + { + return Err(error); + } + sync_backoff(self.options.retry_backoff_ms_value()); + } + } + } + Err(PythonError::Pool( + "connection pool retry loop exited unexpectedly".to_string(), + )) + } + + fn try_begin_raw_query_once(&self, expr: &str) -> Result { + let payload = encode_sync_query(expr).map_err(crate::errors::map_transport_error)?; + let pool = self.pool()?; + let mut connection = pool + .get() + .map_err(|error| PythonError::Pool(error.to_string()))?; + match begin_streaming_frame_over(&mut *connection, &payload) { + Ok((header_bytes, remaining_body)) => { + let header = match response_header(header_bytes) { + Ok(header) => header, + Err(error) => { + connection.mark_broken(); + return Err(error); + } + }; + Ok(( + header, + header_bytes, + remaining_body, + Box::new(SyncPooledConnectionLease::new(connection)), + )) + } + Err(error) => { + connection.mark_broken(); + Err(crate::errors::map_transport_error(error)) + } + } + } + + fn begin_raw_query(&self, expr: &str) -> Result { + for attempt in 0..=self.options.retry_attempts_value() { + match self.try_begin_raw_query_once(expr) { + Ok(response) => return Ok(response), + Err(error) => { + if attempt == self.options.retry_attempts_value() + || !retryable_pool_error(&error) + { + return Err(error); + } + sync_backoff(self.options.retry_backoff_ms_value()); + } + } + } + Err(PythonError::Pool( + "connection pool retry loop exited unexpectedly".to_string(), + )) + } +} + +async fn async_request( + conn: &mut AsyncPooledTransport, + payload: &[u8], +) -> Result, PythonError> { + match qroissant_transport::request_frame_streaming_over_async(conn, payload).await { + Ok(frame) => Ok(frame), + Err(error) => { + conn.mark_broken(); + Err(crate::errors::map_transport_error(error)) + } + } +} + +async fn async_request_value( + conn: &mut AsyncPooledTransport, + payload: &[u8], +) -> Result { + match request_value_pipelined_over_async(conn, payload).await { + Ok(value) => Ok(value), + 
Err(error) => { + conn.mark_broken(); + Err(crate::errors::map_transport_error(error)) + } + } +} + +async fn async_send_query( + conn: &mut AsyncPooledTransport, + query: &str, +) -> Result, PythonError> { + let payload = encode_sync_query(query).map_err(crate::errors::map_transport_error)?; + async_request(conn, &payload).await +} + +async fn async_validate_connection( + conn: &mut AsyncPooledTransport, + healthcheck_query: Option<&str>, +) -> Result<(), PoolBackendError> { + if conn.is_broken() { + return Err(PoolBackendError::Message( + "pooled connection is broken".to_string(), + )); + } + if let Some(query) = healthcheck_query { + let response = async_send_query(conn, query).await?; + validate_success_response(&response)?; + } + Ok(()) +} + +#[derive(Clone)] +struct AsyncTransportManager { + endpoint: Endpoint, + healthcheck_query: Option, +} + +impl AsyncManageConnection for AsyncTransportManager { + type Connection = AsyncPooledTransport; + type Error = PoolBackendError; + + async fn connect(&self) -> Result { + connect_async_transport(&self.endpoint) + .await + .map(AsyncPooledTransport::new) + .map_err(Into::into) + } + + async fn is_valid(&self, conn: &mut Self::Connection) -> Result<(), Self::Error> { + async_validate_connection(conn, self.healthcheck_query.as_deref()).await + } + + fn has_broken(&self, conn: &mut Self::Connection) -> bool { + conn.is_broken() + } +} + +struct AsyncPooledConnectionLease { + connection: Option>, + reusable: bool, +} + +impl AsyncPooledConnectionLease { + fn new(connection: bb8::PooledConnection<'static, AsyncTransportManager>) -> Self { + Self { + connection: Some(connection), + reusable: false, + } + } + + fn mark_inner_broken(&mut self) { + if let Some(connection) = self.connection.as_mut() { + connection.mark_broken(); + } + } +} + +impl AsyncRead for AsyncPooledConnectionLease { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + let connection = self + 
.connection + .as_mut() + .expect("active async pooled lease must hold a checked-out connection"); + Pin::new(&mut **connection).poll_read(cx, buf) + } +} + +impl AsyncWrite for AsyncPooledConnectionLease { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + let connection = self + .connection + .as_mut() + .expect("active async pooled lease must hold a checked-out connection"); + Pin::new(&mut **connection).poll_write(cx, buf) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let connection = self + .connection + .as_mut() + .expect("active async pooled lease must hold a checked-out connection"); + Pin::new(&mut **connection).poll_flush(cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let connection = self + .connection + .as_mut() + .expect("active async pooled lease must hold a checked-out connection"); + Pin::new(&mut **connection).poll_shutdown(cx) + } +} + +impl AsyncStreamingLease for AsyncPooledConnectionLease { + fn mark_reusable(&mut self) { + self.reusable = true; + } + + fn abandon(&mut self) { + self.reusable = false; + self.mark_inner_broken(); + } +} + +impl Drop for AsyncPooledConnectionLease { + fn drop(&mut self) { + if !self.reusable { + self.mark_inner_broken(); + } + } +} + +struct AsyncPoolState { + pool: Option>, + closed: bool, +} + +struct AsyncPoolCore { + manager: AsyncTransportManager, + options: PoolOptions, + state: Mutex, +} + +impl AsyncPoolCore { + fn new(endpoint: Endpoint, options: PoolOptions) -> Result { + endpoint.validate().map_err(PythonError::Pool)?; + Ok(Self { + manager: AsyncTransportManager { + endpoint, + healthcheck_query: options.healthcheck_query_value().map(str::to_string), + }, + options, + state: Mutex::new(AsyncPoolState { + pool: None, + closed: false, + }), + }) + } + + async fn build_pool(&self) -> Result, PoolBackendError> { + let mut builder = bb8::Pool::builder() + 
.max_size(self.options.max_size_value()) + .connection_timeout(Duration::from_millis( + self.options.checkout_timeout_ms_value(), + )) + .test_on_check_out(self.options.test_on_checkout_value()); + builder = builder.min_idle(self.options.min_idle_value()); + builder = builder.idle_timeout( + self.options + .idle_timeout_ms_value() + .map(Duration::from_millis), + ); + builder = builder.max_lifetime( + self.options + .max_lifetime_ms_value() + .map(Duration::from_millis), + ); + builder + .build(self.manager.clone()) + .await + .map_err(|error| PoolBackendError::Message(error.to_string())) + } + + async fn pool(&self) -> Result, PythonError> { + { + let state = self + .state + .lock() + .map_err(|_| PythonError::Pool("pool lock is poisoned".to_string()))?; + if state.closed { + return Err(PythonError::PoolClosed); + } + if let Some(pool) = &state.pool { + return Ok(pool.clone()); + } + } + + let built = self + .build_pool() + .await + .map_err(|error| PythonError::Pool(error.to_string()))?; + let mut state = self + .state + .lock() + .map_err(|_| PythonError::Pool("pool lock is poisoned".to_string()))?; + if state.closed { + return Err(PythonError::PoolClosed); + } + if let Some(pool) = &state.pool { + return Ok(pool.clone()); + } + state.pool = Some(built.clone()); + Ok(built) + } + + fn close(&self) -> Result<(), PythonError> { + let mut state = self + .state + .lock() + .map_err(|_| PythonError::Pool("pool lock is poisoned".to_string()))?; + state.closed = true; + state.pool.take(); + Ok(()) + } + + fn metrics(&self) -> Result { + let state = self + .state + .lock() + .map_err(|_| PythonError::Pool("pool lock is poisoned".to_string()))?; + let (connections, idle_connections, initialized) = state + .pool + .as_ref() + .map(|pool| { + let pool_state = pool.state(); + (pool_state.connections, pool_state.idle_connections, true) + }) + .unwrap_or((0, 0, false)); + Ok(PoolMetrics::new_native( + connections, + idle_connections, + self.options.max_size_value(), + 
self.options.min_idle_value(), + initialized, + state.closed, + )) + } + + async fn prewarm(&self) -> Result { + let target = self.options.min_idle_value().unwrap_or(1); + if target == 0 { + return self.metrics(); + } + let pool = self.pool().await?; + let mut connections = Vec::with_capacity(target as usize); + for _ in 0..target { + connections.push( + pool.get() + .await + .map_err(|error| PythonError::Pool(error.to_string()))?, + ); + } + drop(connections); + self.metrics() + } + + async fn try_query_value_once(&self, expr: &str) -> Result { + let payload = encode_sync_query(expr).map_err(crate::errors::map_transport_error)?; + let pool = self.pool().await?; + let mut connection = pool + .get() + .await + .map_err(|error| PythonError::Pool(error.to_string()))?; + async_request_value(&mut connection, &payload).await + } + + async fn query_value(&self, expr: &str) -> Result { + for attempt in 0..=self.options.retry_attempts_value() { + match self.try_query_value_once(expr).await { + Ok(value) => return Ok(value), + Err(error) => { + if attempt == self.options.retry_attempts_value() + || !retryable_pool_error(&error) + { + return Err(error); + } + async_backoff(self.options.retry_backoff_ms_value()).await; + } + } + } + Err(PythonError::Pool( + "connection pool retry loop exited unexpectedly".to_string(), + )) + } + + async fn try_begin_raw_query_once(&self, expr: &str) -> Result { + let payload = encode_sync_query(expr).map_err(crate::errors::map_transport_error)?; + let pool = self.pool().await?; + let mut connection = pool + .get_owned() + .await + .map_err(|error| PythonError::Pool(error.to_string()))?; + match begin_streaming_frame_over_async(&mut *connection, &payload).await { + Ok((header_bytes, remaining_body)) => { + let header = match response_header(header_bytes) { + Ok(header) => header, + Err(error) => { + connection.mark_broken(); + return Err(error); + } + }; + Ok(( + header, + header_bytes, + remaining_body, + 
AsyncPooledConnectionLease::new(connection), + )) + } + Err(error) => { + connection.mark_broken(); + Err(crate::errors::map_transport_error(error)) + } + } + } + + async fn begin_raw_query(&self, expr: &str) -> Result { + for attempt in 0..=self.options.retry_attempts_value() { + match self.try_begin_raw_query_once(expr).await { + Ok(response) => return Ok(response), + Err(error) => { + if attempt == self.options.retry_attempts_value() + || !retryable_pool_error(&error) + { + return Err(error); + } + async_backoff(self.options.retry_backoff_ms_value()).await; + } + } + } + Err(PythonError::Pool( + "connection pool retry loop exited unexpectedly".to_string(), + )) + } +} + +#[pyclass(module = "qroissant")] +pub struct Connection { + core: Arc, + options: Option, +} + +#[pymethods] +impl Connection { + #[new] + #[pyo3(signature = (endpoint, *, options=None))] + fn new( + py: Python<'_>, + endpoint: PyRef<'_, Endpoint>, + options: Option>, + ) -> PyResult { + let endpoint = endpoint.clone(); + let core = py + .detach(move || ConnectionCore::connect(&endpoint)) + .map_err(to_py_err)?; + Ok(Self { + core: Arc::new(core), + options: options.map(|value| value.clone()), + }) + } + + fn __enter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> { + slf + } + + fn __exit__( + &self, + _exc_type: Option<&Bound<'_, PyAny>>, + _exc_val: Option<&Bound<'_, PyAny>>, + _exc_tb: Option<&Bound<'_, PyAny>>, + ) -> PyResult<()> { + self.core.close().map_err(to_py_err) + } + + fn close(&self, py: Python<'_>) -> PyResult<()> { + py.detach(|| self.core.close()).map_err(to_py_err) + } + + #[pyo3(signature = (expr, /, *, raw=false, decode=None))] + fn query( + &self, + py: Python<'_>, + expr: String, + raw: bool, + decode: Option>, + ) -> PyResult> { + if raw { + let (header, header_bytes, remaining_body, lease) = py + .detach(|| self.core.begin_raw_query(&expr)) + .map_err(to_py_err)?; + let response = RawResponse::streaming(header, header_bytes, remaining_body, lease); + return Ok(Py::new(py, 
response)?.into_any()); + } + + let payload = py + .detach(|| self.core.query_frame(&expr)) + .map_err(to_py_err)?; + let options = decode.as_deref().or(self.options.as_ref()); + let (value, opts) = + decode_core_value(bytes::Bytes::from(payload), options).map_err(to_py_err)?; + core_value_to_python_with_opts(py, value, opts) + } +} + +#[pyclass(module = "qroissant")] +pub struct AsyncConnection { + core: Arc, + options: Option, +} + +#[pymethods] +impl AsyncConnection { + #[new] + #[pyo3(signature = (endpoint, *, options=None))] + fn new( + endpoint: PyRef<'_, Endpoint>, + options: Option>, + ) -> PyResult { + Ok(Self { + core: Arc::new(AsyncConnectionCore::new(endpoint.clone()).map_err(to_py_err)?), + options: options.map(|value| value.clone()), + }) + } + + fn __aenter__<'py>(slf: PyRef<'py, Self>, py: Python<'py>) -> PyResult> { + let core = slf.core.clone(); + let options = slf.options.clone(); + async_self(py, move |py| Py::new(py, Self { core, options })) + } + + fn __aexit__<'py>( + &self, + py: Python<'py>, + _exc_type: Option<&Bound<'_, PyAny>>, + _exc_val: Option<&Bound<'_, PyAny>>, + _exc_tb: Option<&Bound<'_, PyAny>>, + ) -> PyResult> { + let core = self.core.clone(); + future_into_py(py, async move { + core.close().await.map_err(to_py_err)?; + Ok(false) + }) + } + + fn close<'py>(&self, py: Python<'py>) -> PyResult> { + let core = self.core.clone(); + future_into_py(py, async move { + core.close().await.map_err(to_py_err)?; + Ok(()) + }) + } + + #[pyo3(signature = (expr, /, *, raw=false, decode=None))] + fn query<'py>( + &self, + py: Python<'py>, + expr: String, + raw: bool, + decode: Option, + ) -> PyResult> { + let core = self.core.clone(); + let default_options = self.options.clone(); + future_into_py(py, async move { + if raw { + let (header, header_bytes, remaining_body, lease) = + core.begin_raw_query(&expr).await.map_err(to_py_err)?; + return Python::attach(|py| { + let response = RawResponse::streaming( + header, + header_bytes, + 
remaining_body, + Box::new(BlockingAsyncBridge::new(lease)), + ); + Py::new(py, response.into_async()).map(|value| value.into_any()) + }); + } + + let value = core.query_value(&expr).await.map_err(to_py_err)?; + let options = decode.as_ref().or(default_options.as_ref()); + let proj_opts = crate::serde::decode_options_to_proj_opts(options); + Python::attach(|py| core_value_to_python_with_opts(py, value, proj_opts)) + }) + } +} + +#[pyclass(module = "qroissant")] +pub struct Pool { + core: Arc, + options: Option, +} + +#[pymethods] +impl Pool { + #[new] + #[pyo3(signature = (endpoint, *, options=None, pool=None))] + fn new( + endpoint: PyRef<'_, Endpoint>, + options: Option>, + pool: Option>, + ) -> PyResult { + let core = SyncPoolCore::new( + endpoint.clone(), + pool.map(|value| value.clone()).unwrap_or_default(), + ) + .map_err(to_py_err)?; + Ok(Self { + core: Arc::new(core), + options: options.map(|value| value.clone()), + }) + } + + fn __enter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> { + slf + } + + fn __exit__( + &self, + _exc_type: Option<&Bound<'_, PyAny>>, + _exc_val: Option<&Bound<'_, PyAny>>, + _exc_tb: Option<&Bound<'_, PyAny>>, + ) -> PyResult<()> { + self.core.close().map_err(to_py_err) + } + + fn close(&self, py: Python<'_>) -> PyResult<()> { + py.detach(|| self.core.close()).map_err(to_py_err) + } + + fn prewarm(&self, py: Python<'_>) -> PyResult { + py.detach(|| self.core.prewarm()).map_err(to_py_err) + } + + fn metrics(&self) -> PyResult { + self.core.metrics().map_err(to_py_err) + } + + #[pyo3(signature = (expr, /, *, raw=false, decode=None))] + fn query( + &self, + py: Python<'_>, + expr: String, + raw: bool, + decode: Option>, + ) -> PyResult> { + if raw { + let (header, header_bytes, remaining_body, lease) = py + .detach(|| self.core.begin_raw_query(&expr)) + .map_err(to_py_err)?; + let response = RawResponse::streaming(header, header_bytes, remaining_body, lease); + return Ok(Py::new(py, response)?.into_any()); + } + + let payload = py + 
.detach(|| self.core.query_frame(&expr)) + .map_err(to_py_err)?; + let options = decode.as_deref().or(self.options.as_ref()); + let (value, opts) = + decode_core_value(bytes::Bytes::from(payload), options).map_err(to_py_err)?; + core_value_to_python_with_opts(py, value, opts) + } +} + +#[pyclass(module = "qroissant")] +pub struct AsyncPool { + core: Arc, + options: Option, +} + +#[pymethods] +impl AsyncPool { + #[new] + #[pyo3(signature = (endpoint, *, options=None, pool=None))] + fn new( + endpoint: PyRef<'_, Endpoint>, + options: Option>, + pool: Option>, + ) -> PyResult { + let core = AsyncPoolCore::new( + endpoint.clone(), + pool.map(|value| value.clone()).unwrap_or_default(), + ) + .map_err(to_py_err)?; + Ok(Self { + core: Arc::new(core), + options: options.map(|value| value.clone()), + }) + } + + fn __aenter__<'py>(slf: PyRef<'py, Self>, py: Python<'py>) -> PyResult> { + let core = slf.core.clone(); + let options = slf.options.clone(); + async_self(py, move |py| Py::new(py, Self { core, options })) + } + + fn __aexit__<'py>( + &self, + py: Python<'py>, + _exc_type: Option<&Bound<'_, PyAny>>, + _exc_val: Option<&Bound<'_, PyAny>>, + _exc_tb: Option<&Bound<'_, PyAny>>, + ) -> PyResult> { + let core = self.core.clone(); + future_into_py(py, async move { + core.close().map_err(to_py_err)?; + Ok(false) + }) + } + + fn close<'py>(&self, py: Python<'py>) -> PyResult> { + let core = self.core.clone(); + future_into_py(py, async move { + core.close().map_err(to_py_err)?; + Ok(()) + }) + } + + fn prewarm<'py>(&self, py: Python<'py>) -> PyResult> { + let core = self.core.clone(); + future_into_py(py, async move { + let metrics = core.prewarm().await.map_err(to_py_err)?; + Python::attach(|py| Py::new(py, metrics).map(|value| value.into_any())) + }) + } + + fn metrics<'py>(&self, py: Python<'py>) -> PyResult> { + let core = self.core.clone(); + future_into_py(py, async move { + let metrics = core.metrics().map_err(to_py_err)?; + Python::attach(|py| Py::new(py, 
metrics).map(|value| value.into_any())) + }) + } + + #[pyo3(signature = (expr, /, *, raw=false, decode=None))] + fn query<'py>( + &self, + py: Python<'py>, + expr: String, + raw: bool, + decode: Option, + ) -> PyResult> { + let core = self.core.clone(); + let default_options = self.options.clone(); + future_into_py(py, async move { + if raw { + let (header, header_bytes, remaining_body, lease) = + core.begin_raw_query(&expr).await.map_err(to_py_err)?; + return Python::attach(|py| { + let response = RawResponse::streaming( + header, + header_bytes, + remaining_body, + Box::new(BlockingAsyncBridge::new(lease)), + ); + Py::new(py, response.into_async()).map(|value| value.into_any()) + }); + } + + let value = core.query_value(&expr).await.map_err(to_py_err)?; + let options = decode.as_ref().or(default_options.as_ref()); + let proj_opts = crate::serde::decode_options_to_proj_opts(options); + Python::attach(|py| core_value_to_python_with_opts(py, value, proj_opts)) + }) + } +} + +pub fn register(module: &Bound<'_, PyModule>) -> PyResult<()> { + module.add_class::()?; + module.add_class::()?; + module.add_class::()?; + module.add_class::()?; + Ok(()) +} diff --git a/crates/qroissant-python/src/errors.rs b/crates/qroissant-python/src/errors.rs new file mode 100644 index 0000000..0edb00c --- /dev/null +++ b/crates/qroissant-python/src/errors.rs @@ -0,0 +1,114 @@ +use pyo3::create_exception; +use pyo3::exceptions::PyException; +use pyo3::exceptions::PyNotImplementedError; +use pyo3::prelude::*; +use pyo3::types::PyModule; +use qroissant_transport::TransportError; +use thiserror::Error; + +create_exception!( + qroissant, + QroissantError, + PyException, + "Base exception for qroissant errors." +); +create_exception!( + qroissant, + DecodeError, + QroissantError, + "Raised when q IPC payload decoding fails." +); +create_exception!( + qroissant, + ProtocolError, + QroissantError, + "Raised when q IPC framing or protocol validation fails." 
+); +create_exception!( + qroissant, + TransportErrorPy, + QroissantError, + "Raised when transport IO or socket operations fail." +); +create_exception!( + qroissant, + OperationError, + QroissantError, + "Raised when an operation is unsupported in the current state." +); +create_exception!( + qroissant, + QRuntimeError, + QroissantError, + "Raised when the remote q process returns an error response." +); +create_exception!( + qroissant, + PoolError, + QroissantError, + "Raised when connection pool management fails." +); +create_exception!( + qroissant, + PoolClosedError, + PoolError, + "Raised when a closed pool is used." +); + +#[derive(Debug, Error)] +pub enum PythonError { + #[error("{0}")] + Decode(String), + #[error("{0}")] + Protocol(String), + #[error("{0}")] + Transport(String), + #[error("{0}")] + Operation(String), + #[error("{0}")] + QRuntime(String), + #[error("{0}")] + Pool(String), + #[error("connection pool is closed")] + PoolClosed, + #[error("{0}")] + NotImplemented(String), +} + +pub type PythonResult = Result; + +pub fn register(module: &Bound<'_, PyModule>) -> PyResult<()> { + let py = module.py(); + module.add("QroissantError", py.get_type::())?; + module.add("DecodeError", py.get_type::())?; + module.add("ProtocolError", py.get_type::())?; + module.add("TransportError", py.get_type::())?; + module.add("OperationError", py.get_type::())?; + module.add("QRuntimeError", py.get_type::())?; + module.add("PoolError", py.get_type::())?; + module.add("PoolClosedError", py.get_type::())?; + Ok(()) +} + +pub fn to_py_err(error: PythonError) -> PyErr { + match error { + PythonError::Decode(message) => DecodeError::new_err(message), + PythonError::Protocol(message) => ProtocolError::new_err(message), + PythonError::Transport(message) => TransportErrorPy::new_err(message), + PythonError::Operation(message) => OperationError::new_err(message), + PythonError::QRuntime(message) => QRuntimeError::new_err(message), + PythonError::Pool(message) => 
PoolError::new_err(message), + PythonError::PoolClosed => PoolClosedError::new_err("connection pool is closed"), + PythonError::NotImplemented(message) => PyNotImplementedError::new_err(message), + } +} + +pub fn map_transport_error(error: TransportError) -> PythonError { + match error { + TransportError::Closed => PythonError::Operation(error.to_string()), + TransportError::Protocol(_) => PythonError::Protocol(error.to_string()), + TransportError::Io(_) + | TransportError::InvalidEndpoint(_) + | TransportError::InvalidQueryLength(_) => PythonError::Transport(error.to_string()), + } +} diff --git a/crates/qroissant-python/src/lib.rs b/crates/qroissant-python/src/lib.rs new file mode 100644 index 0000000..4369a26 --- /dev/null +++ b/crates/qroissant-python/src/lib.rs @@ -0,0 +1,28 @@ +#![allow(deprecated)] + +//! Native Python module for qroissant. + +mod client; +mod errors; +mod raw_response; +mod repr; +mod serde; +mod types; +mod values; + +use pyo3::prelude::*; +use pyo3::types::PyModule; + +#[pymodule] +fn _native(_py: Python<'_>, module: &Bound<'_, PyModule>) -> PyResult<()> { + module.add("__doc__", "Native qroissant extension")?; + module.add("__version__", env!("CARGO_PKG_VERSION"))?; + errors::register(module)?; + types::register(module)?; + repr::register(module)?; + values::register(module)?; + raw_response::register(module)?; + client::register(module)?; + serde::register(module)?; + Ok(()) +} diff --git a/crates/qroissant-python/src/raw_response.rs b/crates/qroissant-python/src/raw_response.rs new file mode 100644 index 0000000..56ec08c --- /dev/null +++ b/crates/qroissant-python/src/raw_response.rs @@ -0,0 +1,777 @@ +use std::fmt; +use std::io::Read; +use std::sync::Arc; +use std::sync::Mutex; +use std::sync::MutexGuard; + +use pyo3::buffer::PyBuffer; +use pyo3::prelude::*; +use pyo3::types::PyAny; +use pyo3::types::PyBytes; +use pyo3::types::PyModule; +use pyo3_async_runtimes::tokio::future_into_py; +use qroissant_core::HEADER_LEN; +use 
qroissant_core::MessageHeader as CoreMessageHeader; +use tokio::io::AsyncRead; +use tokio::io::AsyncReadExt; +use tokio::io::AsyncWrite; +use tokio::task::spawn_blocking; + +use crate::serde::decode_core_value; +use crate::types::Compression; +use crate::types::DecodeOptions; +use crate::types::Encoding; +use crate::types::MessageHeader; +use crate::types::MessageType; +use crate::values::core_value_to_python_with_opts; + +pub(crate) trait SyncRawLease: Read + Send { + fn mark_reusable(&mut self); + fn abandon(&mut self); +} + +pub(crate) trait AsyncStreamingLease: AsyncRead + AsyncWrite + Send + Unpin { + fn mark_reusable(&mut self); + fn abandon(&mut self); +} + +pub(crate) struct BlockingAsyncBridge { + inner: T, + handle: tokio::runtime::Handle, +} + +impl BlockingAsyncBridge { + pub(crate) fn new(inner: T) -> Self { + Self { + inner, + handle: tokio::runtime::Handle::current(), + } + } +} + +impl Read for BlockingAsyncBridge +where + T: AsyncStreamingLease, +{ + fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + let handle = self.handle.clone(); + let inner = &mut self.inner; + let fut = async move { inner.read(buf).await }; + if tokio::runtime::Handle::try_current().is_ok() { + tokio::task::block_in_place(|| handle.block_on(fut)) + } else { + handle.block_on(fut) + } + } +} + +impl SyncRawLease for BlockingAsyncBridge +where + T: AsyncStreamingLease, +{ + fn mark_reusable(&mut self) { + self.inner.mark_reusable(); + } + + fn abandon(&mut self) { + self.inner.abandon(); + } +} + +fn closed_raw_response_error() -> PyErr { + pyo3::exceptions::PyValueError::new_err("I/O operation on closed qroissant raw response") +} + +fn backend_lock_error() -> PyErr { + pyo3::exceptions::PyRuntimeError::new_err("qroissant raw response state is poisoned") +} + +fn unsupported_seek_error() -> PyErr { + pyo3::exceptions::PyOSError::new_err( + "qroissant raw streaming responses are forward-only and do not support seek()", + ) +} + +fn readonly_buffer_error() -> PyErr { + 
pyo3::exceptions::PyTypeError::new_err("readinto() requires a writable buffer") +} + +fn non_contiguous_buffer_error() -> PyErr { + pyo3::exceptions::PyTypeError::new_err("readinto() requires a C-contiguous buffer") +} + +#[derive(Debug)] +enum RawReadError { + Closed, + BackendPoisoned, + PartiallyConsumed, + Io(std::io::Error), +} + +impl From for RawReadError { + fn from(error: std::io::Error) -> Self { + Self::Io(error) + } +} + +impl fmt::Display for RawReadError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Closed => write!(f, "raw response is closed"), + Self::BackendPoisoned => write!(f, "raw response backend is poisoned"), + Self::PartiallyConsumed => { + write!(f, "raw response has already been partially consumed") + } + Self::Io(error) => error.fmt(f), + } + } +} + +fn raw_read_error_to_py(error: RawReadError) -> PyErr { + match error { + RawReadError::Closed => closed_raw_response_error(), + RawReadError::BackendPoisoned => backend_lock_error(), + RawReadError::PartiallyConsumed => pyo3::exceptions::PyValueError::new_err( + "cannot decode a partially consumed raw response", + ), + RawReadError::Io(error) => PyErr::from(error), + } +} + +fn extract_writable_contiguous_u8_buffer(payload: &Bound<'_, PyAny>) -> PyResult> { + let buffer = PyBuffer::::get(payload)?; + if buffer.readonly() { + return Err(readonly_buffer_error()); + } + if !buffer.is_c_contiguous() { + return Err(non_contiguous_buffer_error()); + } + Ok(buffer) +} + +enum RawResponseBackend { + Buffered { + payload: Vec, + position: usize, + }, + Streaming { + header_bytes: [u8; HEADER_LEN], + header_position: usize, + remaining_body: usize, + position: usize, + lease: Option>, + }, + Closed, +} + +struct RawResponseState { + header: MessageHeader, + backend: RawResponseBackend, +} + +impl RawResponseState { + fn streaming_remaining_total(header_position: usize, remaining_body: usize) -> usize { + (HEADER_LEN - header_position) + remaining_body + } + + fn 
finalize_stream(lease: &mut Option>, reusable: bool) { + if let Some(mut lease) = lease.take() { + if reusable { + lease.mark_reusable(); + } else { + lease.abandon(); + } + } + } + + fn close_backend(backend: &mut RawResponseBackend) { + let backend = std::mem::replace(backend, RawResponseBackend::Closed); + match backend { + RawResponseBackend::Buffered { .. } | RawResponseBackend::Closed => {} + RawResponseBackend::Streaming { + remaining_body, + header_position, + mut lease, + .. + } => { + let reusable = + Self::streaming_remaining_total(header_position, remaining_body) == 0; + Self::finalize_stream(&mut lease, reusable); + } + } + } + + fn read_streaming_into( + header_bytes: &[u8; HEADER_LEN], + header_position: &mut usize, + remaining_body: &mut usize, + position: &mut usize, + lease: &mut Option>, + out: &mut [u8], + ) -> Result { + let total_remaining = Self::streaming_remaining_total(*header_position, *remaining_body); + if total_remaining == 0 { + Self::finalize_stream(lease, true); + return Ok(0); + } + + let target = out.len().min(total_remaining); + let mut filled = 0_usize; + let header_copied = if *header_position < HEADER_LEN && filled < target { + let available = HEADER_LEN - *header_position; + let to_copy = (target - filled).min(available); + out[..to_copy] + .copy_from_slice(&header_bytes[*header_position..*header_position + to_copy]); + *header_position += to_copy; + filled += to_copy; + to_copy + } else { + 0 + }; + + if filled < target { + while filled < target { + let lease_ref = lease + .as_mut() + .expect("streaming raw responses must hold an active lease"); + let read = lease_ref.read(&mut out[filled..target])?; + if read == 0 { + Self::finalize_stream(lease, false); + return Err(std::io::Error::from(std::io::ErrorKind::UnexpectedEof).into()); + } + filled += read; + } + } + + let body_bytes = filled.saturating_sub(header_copied); + if body_bytes != 0 { + *remaining_body = remaining_body.saturating_sub(body_bytes); + } + *position = 
position.saturating_add(filled); + + if Self::streaming_remaining_total(*header_position, *remaining_body) == 0 { + Self::finalize_stream(lease, true); + } + Ok(filled) + } +} + +fn header_from_payload(payload: &[u8]) -> PyResult { + if payload.len() < HEADER_LEN { + return Ok(MessageHeader::new_native( + Encoding::LittleEndian, + MessageType::Response, + Compression::Uncompressed, + payload.len(), + )); + } + let header = CoreMessageHeader::parse(payload) + .map_err(|error| pyo3::exceptions::PyValueError::new_err(error.to_string()))?; + Ok(MessageHeader::from(header)) +} + +#[pyclass(module = "qroissant")] +pub struct RawResponse { + state: Arc>, +} + +impl RawResponse { + fn lock_state_result(&self) -> Result, RawReadError> { + self.state.lock().map_err(|_| RawReadError::BackendPoisoned) + } + + fn lock_state(&self) -> PyResult> { + self.lock_state_result().map_err(raw_read_error_to_py) + } + + fn ensure_open(backend: &RawResponseBackend) -> PyResult<()> { + if matches!(backend, RawResponseBackend::Closed) { + return Err(closed_raw_response_error()); + } + Ok(()) + } + + fn ensure_open_result(backend: &RawResponseBackend) -> Result<(), RawReadError> { + if matches!(backend, RawResponseBackend::Closed) { + return Err(RawReadError::Closed); + } + Ok(()) + } + + pub(crate) fn buffered(payload: Vec) -> PyResult { + let header = header_from_payload(&payload)?; + Ok(Self { + state: Arc::new(Mutex::new(RawResponseState { + header, + backend: RawResponseBackend::Buffered { + payload, + position: 0, + }, + })), + }) + } + + pub(crate) fn streaming( + header: MessageHeader, + header_bytes: [u8; HEADER_LEN], + remaining_body: usize, + lease: Box, + ) -> Self { + Self { + state: Arc::new(Mutex::new(RawResponseState { + header, + backend: RawResponseBackend::Streaming { + header_bytes, + header_position: 0, + remaining_body, + position: 0, + lease: Some(lease), + }, + })), + } + } + + pub(crate) fn into_async(self) -> AsyncRawResponse { + let this = 
std::mem::ManuallyDrop::new(self); + // SAFETY: `ManuallyDrop` suppresses `RawResponse::drop`, so it is safe + // to move the owned `Arc` into the async wrapper without closing the + // underlying raw-response state. + let state = unsafe { std::ptr::read(&this.state) }; + AsyncRawResponse { state } + } + + fn materialize_result(&self) -> Result, RawReadError> { + let position = { + let state = self.lock_state_result()?; + Self::ensure_open_result(&state.backend)?; + match &state.backend { + RawResponseBackend::Buffered { position, .. } + | RawResponseBackend::Streaming { position, .. } => *position, + RawResponseBackend::Closed => { + unreachable!("closed raw responses are handled above") + } + } + }; + if position != 0 { + return Err(RawReadError::PartiallyConsumed); + } + self.read_owned_result(None) + } + + fn read_owned_result(&self, size: Option) -> Result, RawReadError> { + let mut state = self.lock_state_result()?; + Self::ensure_open_result(&state.backend)?; + match &mut state.backend { + RawResponseBackend::Buffered { payload, position } => { + if *position >= payload.len() { + return Ok(Vec::new()); + } + let remaining = payload.len() - *position; + let to_read = match size { + Some(size) if size >= 0 => remaining.min(size as usize), + _ => remaining, + }; + let start = *position; + let end = start + to_read; + *position = end; + Ok(payload[start..end].to_vec()) + } + RawResponseBackend::Streaming { + header_bytes, + header_position, + remaining_body, + position, + lease, + } => { + let total_remaining = + RawResponseState::streaming_remaining_total(*header_position, *remaining_body); + let target = match size { + Some(size) if size >= 0 => total_remaining.min(size as usize), + _ => total_remaining, + }; + let mut out = vec![0_u8; target]; + match RawResponseState::read_streaming_into( + header_bytes, + header_position, + remaining_body, + position, + lease, + &mut out, + ) { + Ok(filled) => { + out.truncate(filled); + Ok(out) + } + Err(error) => { + 
state.backend = RawResponseBackend::Closed; + Err(error) + } + } + } + RawResponseBackend::Closed => Err(RawReadError::Closed), + } + } + + fn read_into_result(&self, out: &mut [u8]) -> Result { + let mut state = self.lock_state_result()?; + Self::ensure_open_result(&state.backend)?; + match &mut state.backend { + RawResponseBackend::Buffered { payload, position } => { + if *position >= payload.len() { + return Ok(0); + } + let remaining = payload.len() - *position; + let to_read = remaining.min(out.len()); + let start = *position; + let end = start + to_read; + out[..to_read].copy_from_slice(&payload[start..end]); + *position = end; + Ok(to_read) + } + RawResponseBackend::Streaming { + header_bytes, + header_position, + remaining_body, + position, + lease, + } => match RawResponseState::read_streaming_into( + header_bytes, + header_position, + remaining_body, + position, + lease, + out, + ) { + Ok(filled) => Ok(filled), + Err(error) => { + state.backend = RawResponseBackend::Closed; + Err(error) + } + }, + RawResponseBackend::Closed => Err(RawReadError::Closed), + } + } +} + +impl Drop for RawResponse { + fn drop(&mut self) { + // Clean up even if the mutex is poisoned (panic in another thread). 
+ let mut state = match self.state.lock() { + Ok(guard) => guard, + Err(poisoned) => poisoned.into_inner(), + }; + RawResponseState::close_backend(&mut state.backend); + } +} + +#[pymethods] +impl RawResponse { + #[new] + fn new(payload: Vec) -> PyResult { + Self::buffered(payload) + } + + fn __enter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> { + slf + } + + fn __exit__( + &self, + _exc_type: Option<&Bound<'_, PyAny>>, + _exc_val: Option<&Bound<'_, PyAny>>, + _exc_tb: Option<&Bound<'_, PyAny>>, + ) -> PyResult<()> { + self.close() + } + + #[getter] + fn closed(&self) -> bool { + self.state + .lock() + .map(|state| matches!(state.backend, RawResponseBackend::Closed)) + .unwrap_or(true) + } + + #[getter] + fn header(&self) -> PyResult { + let state = self.lock_state()?; + Ok(state.header.clone()) + } + + fn close(&self) -> PyResult<()> { + let mut state = self.lock_state()?; + RawResponseState::close_backend(&mut state.backend); + Ok(()) + } + + fn readable(&self) -> bool { + !self.closed() + } + + fn seekable(&self) -> bool { + self.state + .lock() + .map(|state| matches!(state.backend, RawResponseBackend::Buffered { .. 
})) + .unwrap_or(false) + } + + #[pyo3(signature = (size=None))] + fn read<'py>(&self, py: Python<'py>, size: Option) -> PyResult> { + let bytes = py + .detach(|| self.read_owned_result(size)) + .map_err(raw_read_error_to_py)?; + Ok(PyBytes::new(py, &bytes)) + } + + #[pyo3(signature = (size=None))] + fn read1<'py>(&self, py: Python<'py>, size: Option) -> PyResult> { + self.read(py, size) + } + + fn readinto(&self, py: Python<'_>, buffer: &Bound<'_, PyAny>) -> PyResult { + let writable = extract_writable_contiguous_u8_buffer(buffer)?; + let len = writable.len_bytes(); + if len == 0 { + let mut empty = []; + return py + .detach(|| self.read_into_result(&mut empty)) + .map_err(raw_read_error_to_py); + } + let ptr = writable.buf_ptr() as usize; + py.detach(move || { + let ptr = ptr as *mut u8; + // SAFETY: the writable Python buffer outlives this detached call and + // the slice length is bounded by the exported buffer length. + let slice = unsafe { std::slice::from_raw_parts_mut(ptr, len) }; + self.read_into_result(slice) + }) + .map_err(raw_read_error_to_py) + } + + fn readinto1(&self, py: Python<'_>, buffer: &Bound<'_, PyAny>) -> PyResult { + self.readinto(py, buffer) + } + + fn tell(&self) -> PyResult { + let state = self.lock_state()?; + Self::ensure_open(&state.backend)?; + match &state.backend { + RawResponseBackend::Buffered { position, .. } + | RawResponseBackend::Streaming { position, .. 
} => Ok(*position), + RawResponseBackend::Closed => Err(closed_raw_response_error()), + } + } + + #[pyo3(signature = (offset, whence=0))] + fn seek(&self, offset: i64, whence: i32) -> PyResult { + let mut state = self.lock_state()?; + Self::ensure_open(&state.backend)?; + match &mut state.backend { + RawResponseBackend::Buffered { payload, position } => { + let base = match whence { + 0 => 0_i64, + 1 => i64::try_from(*position).map_err(|_| { + pyo3::exceptions::PyOverflowError::new_err( + "raw response position exceeds supported seek range", + ) + })?, + 2 => i64::try_from(payload.len()).map_err(|_| { + pyo3::exceptions::PyOverflowError::new_err( + "raw response length exceeds supported seek range", + ) + })?, + _ => { + return Err(pyo3::exceptions::PyValueError::new_err(format!( + "invalid seek whence value {whence}; expected 0, 1, or 2" + ))); + } + }; + let position_i64 = base.checked_add(offset).ok_or_else(|| { + pyo3::exceptions::PyOverflowError::new_err( + "raw response seek position overflowed", + ) + })?; + if position_i64 < 0 { + return Err(pyo3::exceptions::PyValueError::new_err( + "negative seek position is not allowed", + )); + } + *position = usize::try_from(position_i64).map_err(|_| { + pyo3::exceptions::PyOverflowError::new_err( + "raw response seek position overflowed", + ) + })?; + Ok(*position) + } + RawResponseBackend::Streaming { .. 
} => Err(unsupported_seek_error()), + RawResponseBackend::Closed => Err(closed_raw_response_error()), + } + } + + #[pyo3(signature = (*, options=None))] + fn decode(&self, py: Python<'_>, options: Option<&DecodeOptions>) -> PyResult> { + let payload = py + .detach(|| self.materialize_result()) + .map_err(raw_read_error_to_py)?; + let (value, opts) = + decode_core_value(bytes::Bytes::from(payload), options) + .map_err(crate::errors::to_py_err)?; + core_value_to_python_with_opts(py, value, opts) + } + + fn __repr__(&self) -> String { + match self.state.lock() { + Ok(state) => match &state.backend { + RawResponseBackend::Buffered { payload, position } => format!( + "RawResponse(mode='buffered', len={}, position={}, closed=false)", + payload.len(), + position + ), + RawResponseBackend::Streaming { + header_position, + remaining_body, + position, + .. + } => format!( + "RawResponse(mode='streaming', remaining={}, position={}, closed=false)", + RawResponseState::streaming_remaining_total(*header_position, *remaining_body), + position + ), + RawResponseBackend::Closed => "RawResponse(mode='closed', closed=true)".to_string(), + }, + Err(_) => "RawResponse(mode='poisoned', closed=true)".to_string(), + } + } +} + +#[pyclass(module = "qroissant")] +pub struct AsyncRawResponse { + state: Arc>, +} + +#[pymethods] +impl AsyncRawResponse { + fn __aenter__<'py>(slf: PyRef<'py, Self>, py: Python<'py>) -> PyResult> { + let state = slf.state.clone(); + future_into_py(py, async move { + Python::attach(|py| Py::new(py, Self { state }).map(|value| value.into_any())) + }) + } + + fn __aexit__<'py>( + &self, + py: Python<'py>, + _exc_type: Option<&Bound<'_, PyAny>>, + _exc_val: Option<&Bound<'_, PyAny>>, + _exc_tb: Option<&Bound<'_, PyAny>>, + ) -> PyResult> { + let state = self.state.clone(); + future_into_py(py, async move { + let mut state = state.lock().map_err(|_| backend_lock_error())?; + RawResponseState::close_backend(&mut state.backend); + Ok(false) + }) + } + + #[getter] + fn 
closed(&self) -> bool { + self.state + .lock() + .map(|state| matches!(state.backend, RawResponseBackend::Closed)) + .unwrap_or(true) + } + + #[getter] + fn header(&self) -> PyResult { + let state = self.state.lock().map_err(|_| backend_lock_error())?; + Ok(state.header.clone()) + } + + fn close<'py>(&self, py: Python<'py>) -> PyResult> { + let state = self.state.clone(); + future_into_py(py, async move { + let mut state = state.lock().map_err(|_| backend_lock_error())?; + RawResponseState::close_backend(&mut state.backend); + Ok(()) + }) + } + + #[pyo3(signature = (size=None))] + fn read<'py>(&self, py: Python<'py>, size: Option) -> PyResult> { + let raw = RawResponse { + state: self.state.clone(), + }; + future_into_py(py, async move { + let bytes = spawn_blocking(move || raw.read_owned_result(size)) + .await + .map_err(|error| pyo3::exceptions::PyRuntimeError::new_err(error.to_string()))? + .map_err(raw_read_error_to_py)?; + Python::attach(|py| Ok(PyBytes::new(py, &bytes).unbind().into_any())) + }) + } + + #[pyo3(signature = (size=None))] + fn read1<'py>(&self, py: Python<'py>, size: Option) -> PyResult> { + self.read(py, size) + } + + fn readinto<'py>(&self, py: Python<'py>, buffer: Py) -> PyResult> { + let state = self.state.clone(); + future_into_py(py, async move { + Python::attach(|py| { + let buffer = buffer.bind(py); + let writable = extract_writable_contiguous_u8_buffer(buffer)?; + let len = writable.len_bytes(); + if len == 0 { + return Ok(0); + } + let ptr = writable.buf_ptr() as usize; + let raw = RawResponse { + state: state.clone(), + }; + drop(writable); + let read = py + .detach(move || { + let ptr = ptr as *mut u8; + // SAFETY: the writable Python buffer outlives this detached call and + // the slice length is bounded by the exported buffer length. 
+ let slice = unsafe { std::slice::from_raw_parts_mut(ptr, len) }; + raw.read_into_result(slice) + }) + .map_err(raw_read_error_to_py)?; + Ok(read) + }) + }) + } + + fn readinto1<'py>(&self, py: Python<'py>, buffer: Py) -> PyResult> { + self.readinto(py, buffer) + } + + #[pyo3(signature = (*, options=None))] + fn decode<'py>( + &self, + py: Python<'py>, + options: Option, + ) -> PyResult> { + let raw = RawResponse { + state: self.state.clone(), + }; + future_into_py(py, async move { + let payload = spawn_blocking(move || raw.materialize_result()) + .await + .map_err(|error| pyo3::exceptions::PyRuntimeError::new_err(error.to_string()))? + .map_err(raw_read_error_to_py)?; + let (value, opts) = + decode_core_value(bytes::Bytes::from(payload), options.as_ref()) + .map_err(crate::errors::to_py_err)?; + Python::attach(|py| core_value_to_python_with_opts(py, value, opts)) + }) + } +} + +pub fn register(module: &Bound<'_, PyModule>) -> PyResult<()> { + module.add_class::()?; + module.add_class::()?; + Ok(()) +} diff --git a/crates/qroissant-python/src/repr/cell.rs b/crates/qroissant-python/src/repr/cell.rs new file mode 100644 index 0000000..631df8c --- /dev/null +++ b/crates/qroissant-python/src/repr/cell.rs @@ -0,0 +1,437 @@ +//! Cell-level value formatting for q atoms and vector items. +//! +//! Converts raw q IPC values (CoreValue primitives) to human-readable strings +//! without any Arrow dependency. Null sentinels are rendered as `"null"`. +//! Temporal values use ISO-like formats familiar to both q and Python users. 
+ +use chrono::NaiveDate; +use chrono::NaiveDateTime; +use qroissant_core::Atom; +use qroissant_core::VectorData; +use qroissant_kernels::DATE_OFFSET_DAYS; +use qroissant_kernels::MILLIS_PER_DAY; +use qroissant_kernels::Q_NULL_DATE; +use qroissant_kernels::Q_NULL_INT; +use qroissant_kernels::Q_NULL_LONG; +use qroissant_kernels::Q_NULL_MINUTE; +use qroissant_kernels::Q_NULL_MONTH; +use qroissant_kernels::Q_NULL_SECOND; +use qroissant_kernels::Q_NULL_SHORT; +use qroissant_kernels::Q_NULL_TIME; +use qroissant_kernels::Q_NULL_TIMESPAN; +use qroissant_kernels::Q_NULL_TIMESTAMP; +use qroissant_kernels::TIMESTAMP_OFFSET_NS; + +pub const MAX_CELL_CHARS: usize = 48; + +/// Truncate a string to `MAX_CELL_CHARS` characters, appending `"..."` if cut. +pub fn truncate(s: String) -> String { + let mut chars = s.chars(); + let head: String = chars.by_ref().take(MAX_CELL_CHARS).collect(); + if chars.next().is_some() { + format!("{head}...") + } else { + head + } +} + +// --------------------------------------------------------------------------- +// Temporal helpers +// --------------------------------------------------------------------------- + +fn format_date_days(q_days: i32) -> String { + // q dates are days since 2000-01-01; NaiveDate::from_ymd uses Unix days + let unix_days = q_days + DATE_OFFSET_DAYS; + match NaiveDate::from_num_days_from_ce_opt(unix_days + 719_163) { + Some(d) => d.format("%Y.%m.%d").to_string(), + None => format!(""), + } +} + +fn format_timestamp_ns(q_ns: i64) -> String { + let unix_ns = q_ns.saturating_add(TIMESTAMP_OFFSET_NS); + let secs = unix_ns.div_euclid(1_000_000_000); + let nsecs = unix_ns.rem_euclid(1_000_000_000) as u32; + match NaiveDateTime::from_timestamp_opt(secs, nsecs) { + Some(dt) => dt.format("%Y.%m.%dT%H:%M:%S.%9f").to_string(), + None => format!(""), + } +} + +fn format_month_i32(q_months: i32) -> String { + // q months are months since 2000-01; month 0 = 2000.01 + let total_months = 2000 * 12 + q_months; + let year = 
total_months.div_euclid(12); + let month = total_months.rem_euclid(12) + 1; + format!("{year:04}.{month:02}m") +} + +fn format_datetime_f64(q_days: f64) -> String { + let unix_ms = q_days * MILLIS_PER_DAY + 946_684_800_000.0; + let unix_ms_i64 = unix_ms as i64; + let secs = unix_ms_i64.div_euclid(1000); + let ms = unix_ms_i64.rem_euclid(1000) as u32; + match NaiveDateTime::from_timestamp_opt(secs, ms * 1_000_000) { + Some(dt) => dt.format("%Y.%m.%dT%H:%M:%S.%3f").to_string(), + None => format!(""), + } +} + +fn format_timespan_ns(q_ns: i64) -> String { + // Timespans can be negative (use absolute value then sign) + let (sign, abs_ns) = if q_ns < 0 { + ("-", (-(q_ns as i128)) as u64) + } else { + ("", q_ns as u64) + }; + let days = abs_ns / 86_400_000_000_000; + let rem = abs_ns % 86_400_000_000_000; + let hours = rem / 3_600_000_000_000; + let rem = rem % 3_600_000_000_000; + let minutes = rem / 60_000_000_000; + let rem = rem % 60_000_000_000; + let secs = rem / 1_000_000_000; + let ns = rem % 1_000_000_000; + format!("{sign}{days}D{hours:02}:{minutes:02}:{secs:02}.{ns:09}") +} + +fn format_minute_i32(total_minutes: i32) -> String { + let h = total_minutes / 60; + let m = total_minutes % 60; + format!("{h:02}:{m:02}") +} + +fn format_second_i32(total_seconds: i32) -> String { + let h = total_seconds / 3600; + let m = (total_seconds / 60) % 60; + let s = total_seconds % 60; + format!("{h:02}:{m:02}:{s:02}") +} + +fn format_time_ms(total_ms: i32) -> String { + let h = total_ms / 3_600_000; + let m = (total_ms / 60_000) % 60; + let s = (total_ms / 1000) % 60; + let ms = total_ms % 1000; + format!("{h:02}:{m:02}:{s:02}.{ms:03}") +} + +fn format_guid_bytes(bytes: &[u8; 16]) -> String { + format!( + "{:02x}{:02x}{:02x}{:02x}-{:02x}{:02x}-{:02x}{:02x}-{:02x}{:02x}-{:02x}{:02x}{:02x}{:02x}{:02x}{:02x}", + bytes[0], + bytes[1], + bytes[2], + bytes[3], + bytes[4], + bytes[5], + bytes[6], + bytes[7], + bytes[8], + bytes[9], + bytes[10], + bytes[11], + bytes[12], + bytes[13], 
+ bytes[14], + bytes[15], + ) +} + +fn format_symbol_bytes(bytes: &[u8]) -> String { + String::from_utf8_lossy(bytes).into_owned() +} + +// --------------------------------------------------------------------------- +// Public API +// --------------------------------------------------------------------------- + +/// Format a q atom as a display string (no truncation applied). +pub fn format_atom_raw(atom: &Atom) -> String { + match atom { + Atom::Boolean(b) => if *b { "true" } else { "false" }.to_string(), + Atom::Guid(bytes) => format_guid_bytes(bytes), + Atom::Byte(b) => format!("0x{b:02x}"), + Atom::Short(v) => { + if *v == Q_NULL_SHORT { + "null".to_string() + } else { + v.to_string() + } + } + Atom::Int(v) => { + if *v == Q_NULL_INT { + "null".to_string() + } else { + v.to_string() + } + } + Atom::Long(v) => { + if *v == Q_NULL_LONG { + "null".to_string() + } else { + v.to_string() + } + } + Atom::Real(v) => { + if v.is_nan() { + "null".to_string() + } else { + v.to_string() + } + } + Atom::Float(v) => { + if v.is_nan() { + "null".to_string() + } else { + v.to_string() + } + } + Atom::Char(b) => { + let ch = *b as char; + format!("\"{ch}\"") + } + Atom::Symbol(bytes) => format_symbol_bytes(bytes), + Atom::Timestamp(v) => { + if *v == Q_NULL_TIMESTAMP { + "null".to_string() + } else { + format_timestamp_ns(*v) + } + } + Atom::Month(v) => { + if *v == Q_NULL_MONTH { + "null".to_string() + } else { + format_month_i32(*v) + } + } + Atom::Date(v) => { + if *v == Q_NULL_DATE { + "null".to_string() + } else { + format_date_days(*v) + } + } + Atom::Datetime(v) => { + if v.is_nan() { + "null".to_string() + } else { + format_datetime_f64(*v) + } + } + Atom::Timespan(v) => { + if *v == Q_NULL_TIMESPAN { + "null".to_string() + } else { + format_timespan_ns(*v) + } + } + Atom::Minute(v) => { + if *v == Q_NULL_MINUTE { + "null".to_string() + } else { + format_minute_i32(*v) + } + } + Atom::Second(v) => { + if *v == Q_NULL_SECOND { + "null".to_string() + } else { + 
format_second_i32(*v) + } + } + Atom::Time(v) => { + if *v == Q_NULL_TIME { + "null".to_string() + } else { + format_time_ms(*v) + } + } + } +} + +/// Format and truncate a q atom. +pub fn format_atom_cell(atom: &Atom) -> String { + truncate(format_atom_raw(atom)) +} + +/// Format a single element from a `VectorData` at `index` (no truncation). +pub fn format_vector_item_raw(data: &VectorData, index: usize) -> String { + match data { + VectorData::Boolean(v) => if v[index] != 0 { "true" } else { "false" }.to_string(), + VectorData::Guid(v) => { + let chunk: &[u8; 16] = v[index * 16..(index + 1) * 16].try_into().unwrap(); + format_guid_bytes(chunk) + } + VectorData::Byte(v) => format!("0x{:02x}", v[index]), + VectorData::Short(_) => { + let val = data.as_i16_slice()[index]; + if val == Q_NULL_SHORT { + "null".to_string() + } else { + val.to_string() + } + } + VectorData::Int(_) => { + let val = data.as_i32_slice()[index]; + if val == Q_NULL_INT { + "null".to_string() + } else { + val.to_string() + } + } + VectorData::Long(_) => { + let val = data.as_i64_slice()[index]; + if val == Q_NULL_LONG { + "null".to_string() + } else { + val.to_string() + } + } + VectorData::Real(_) => { + let val = data.as_f32_slice()[index]; + if val.is_nan() { + "null".to_string() + } else { + val.to_string() + } + } + VectorData::Float(_) => { + let val = data.as_f64_slice()[index]; + if val.is_nan() { + "null".to_string() + } else { + val.to_string() + } + } + VectorData::Char(v) => { + let ch = v[index] as char; + ch.to_string() + } + VectorData::Symbol(v) => format_symbol_bytes(&v[index]), + VectorData::Timestamp(_) => { + let val = data.as_i64_slice()[index]; + if val == Q_NULL_TIMESTAMP { + "null".to_string() + } else { + format_timestamp_ns(val) + } + } + VectorData::Month(_) => { + let val = data.as_i32_slice()[index]; + if val == Q_NULL_MONTH { + "null".to_string() + } else { + format_month_i32(val) + } + } + VectorData::Date(_) => { + let val = data.as_i32_slice()[index]; + if 
val == Q_NULL_DATE { + "null".to_string() + } else { + format_date_days(val) + } + } + VectorData::Datetime(_) => { + let val = data.as_f64_slice()[index]; + if val.is_nan() { + "null".to_string() + } else { + format_datetime_f64(val) + } + } + VectorData::Timespan(_) => { + let val = data.as_i64_slice()[index]; + if val == Q_NULL_TIMESPAN { + "null".to_string() + } else { + format_timespan_ns(val) + } + } + VectorData::Minute(_) => { + let val = data.as_i32_slice()[index]; + if val == Q_NULL_MINUTE { + "null".to_string() + } else { + format_minute_i32(val) + } + } + VectorData::Second(_) => { + let val = data.as_i32_slice()[index]; + if val == Q_NULL_SECOND { + "null".to_string() + } else { + format_second_i32(val) + } + } + VectorData::Time(_) => { + let val = data.as_i32_slice()[index]; + if val == Q_NULL_TIME { + "null".to_string() + } else { + format_time_ms(val) + } + } + } +} + +/// Format and truncate a single vector item. +pub fn format_vector_item(data: &VectorData, index: usize) -> String { + truncate(format_vector_item_raw(data, index)) +} + +/// Format a char vector as a quoted string (e.g. `"abc"`), truncated. +pub fn format_char_vector(data: &[u8]) -> String { + let s: String = data.iter().map(|&b| b as char).collect(); + truncate(format!("\"{s}\"")) +} + +/// Return the q primitive label for a `VectorData`. 
+pub fn primitive_label(data: &VectorData) -> &'static str {
+    // One arm per variant and no wildcard: the payload is ignored, only the
+    // variant decides the label.
+    match data {
+        VectorData::Boolean(..) => "boolean",
+        VectorData::Guid(..) => "guid",
+        VectorData::Byte(..) => "byte",
+        VectorData::Short(..) => "short",
+        VectorData::Int(..) => "int",
+        VectorData::Long(..) => "long",
+        VectorData::Real(..) => "real",
+        VectorData::Float(..) => "float",
+        VectorData::Char(..) => "char",
+        VectorData::Symbol(..) => "symbol",
+        VectorData::Timestamp(..) => "timestamp",
+        VectorData::Month(..) => "month",
+        VectorData::Date(..) => "date",
+        VectorData::Datetime(..) => "datetime",
+        VectorData::Timespan(..) => "timespan",
+        VectorData::Minute(..) => "minute",
+        VectorData::Second(..) => "second",
+        VectorData::Time(..) => "time",
+    }
+}
+
+/// Map an `Atom` to its lowercase q primitive label (e.g. `"long"`).
+///
+/// Uses the same label set as `primitive_label`, keyed on the atom variant.
+pub fn atom_primitive_label(atom: &Atom) -> &'static str {
+    // One arm per variant and no wildcard: the payload is ignored, only the
+    // variant decides the label.
+    match atom {
+        Atom::Boolean(..) => "boolean",
+        Atom::Guid(..) => "guid",
+        Atom::Byte(..) => "byte",
+        Atom::Short(..) => "short",
+        Atom::Int(..) => "int",
+        Atom::Long(..) => "long",
+        Atom::Real(..) => "real",
+        Atom::Float(..) => "float",
+        Atom::Char(..) => "char",
+        Atom::Symbol(..) => "symbol",
+        Atom::Timestamp(..) => "timestamp",
+        Atom::Month(..) => "month",
+        Atom::Date(..) => "date",
+        Atom::Datetime(..) => "datetime",
+        Atom::Timespan(..) => "timespan",
+        Atom::Minute(..) => "minute",
+        Atom::Second(..) => "second",
+        Atom::Time(..) => "time",
+    }
+}
diff --git a/crates/qroissant-python/src/repr/format.rs b/crates/qroissant-python/src/repr/format.rs
new file mode 100644
index 0000000..32f6685
--- /dev/null
+++ b/crates/qroissant-python/src/repr/format.rs
@@ -0,0 +1,278 @@
+//! High-level format functions for each q value shape.
+//!
+//! Each function produces a multi-line ASCII repr string. Rendering is driven
+//! by the active [`FormattingOptions`] (read from the process-wide global).
+ +use qroissant_core::Atom as CoreAtom; +use qroissant_core::Dictionary as CoreDictionary; +use qroissant_core::List as CoreList; +use qroissant_core::Table as CoreTable; +use qroissant_core::Value as CoreValue; +use qroissant_core::Vector as CoreVector; +use qroissant_core::VectorData; + +use super::cell::atom_primitive_label; +use super::cell::format_atom_cell; +use super::cell::format_atom_raw; +use super::cell::format_char_vector; +use super::cell::format_vector_item; +use super::cell::primitive_label; +use super::cell::truncate; +use super::options::active_options; +use super::render::PreviewSlot; +use super::render::preview_slots; +use super::render::render_preview; + +// --------------------------------------------------------------------------- +// Attribute helper +// --------------------------------------------------------------------------- + +fn attribute_label(attribute: qroissant_core::Attribute) -> &'static str { + match attribute { + qroissant_core::Attribute::None => "none", + qroissant_core::Attribute::Sorted => "sorted", + qroissant_core::Attribute::Unique => "unique", + qroissant_core::Attribute::Parted => "parted", + qroissant_core::Attribute::Grouped => "grouped", + } +} + +// --------------------------------------------------------------------------- +// Atom +// --------------------------------------------------------------------------- + +pub fn format_atom(atom: &CoreAtom) -> String { + let label = atom_primitive_label(atom); + render_preview( + vec![format!("Atom [{label}]")], + vec!["value".to_string()], + vec![vec![format_atom_cell(atom)]], + vec!["shape: (1,)".to_string()], + ) +} + +// --------------------------------------------------------------------------- +// Vector +// --------------------------------------------------------------------------- + +pub fn format_vector(vector: &CoreVector) -> String { + let len = vector.len(); + let data = vector.data(); + let label = primitive_label(data); + let attr = vector.attribute(); + + 
let rows = match data { + VectorData::Char(chars) => { + vec![vec![format_char_vector(chars)]] + } + _ => { + let opts = active_options(); + preview_slots(len, opts.max_rows, opts.row_display) + .into_iter() + .map(|slot| match slot { + PreviewSlot::Index(i) => vec![format_vector_item(data, i)], + PreviewSlot::Ellipsis => vec!["...".to_string()], + }) + .collect() + } + }; + + render_preview( + vec![format!("Vector [{label}, attr={}]", attribute_label(attr))], + vec!["value".to_string()], + rows, + vec![format!("shape: ({len},)")], + ) +} + +// --------------------------------------------------------------------------- +// List +// --------------------------------------------------------------------------- + +/// Compact single-line summary of any `CoreValue` (used for list/dict cells). +fn inline_value_summary(value: &CoreValue) -> String { + match value { + CoreValue::Atom(atom) => truncate(format!( + "{} [{}]", + format_atom_raw(atom), + atom_primitive_label(atom) + )), + CoreValue::Vector(vector) => { + let label = primitive_label(vector.data()); + let len = vector.len(); + match vector.data() { + VectorData::Char(chars) => truncate(format_char_vector(chars)), + _ => truncate(format!("vector<{label}>[{len}]")), + } + } + CoreValue::List(list) => truncate(format!("list[{}]", list.len())), + CoreValue::Dictionary(dict) => truncate(format!("dict[{}]", dict.len())), + CoreValue::Table(table) => { + truncate(format!("table[{}x{}]", table.len(), table.num_columns())) + } + CoreValue::UnaryPrimitive { opcode } => truncate(format!("unary(0x{opcode:02x})")), + } +} + +pub fn format_list(list: &CoreList) -> String { + let len = list.len(); + let opts = active_options(); + let attr = list.attribute(); + + let rows = preview_slots(len, opts.max_rows, opts.row_display) + .into_iter() + .map(|slot| match slot { + PreviewSlot::Index(i) => vec![inline_value_summary(&list.values()[i])], + PreviewSlot::Ellipsis => vec!["...".to_string()], + }) + .collect(); + + render_preview( + 
vec![format!("List [list, attr={}]", attribute_label(attr))], + vec!["value".to_string()], + rows, + vec![format!("shape: ({len},)")], + ) +} + +// --------------------------------------------------------------------------- +// Dictionary +// --------------------------------------------------------------------------- + +pub fn format_dictionary(dict: &CoreDictionary) -> String { + let size = dict.len(); + let sorted = dict.sorted(); + + let all_rows = vec![ + vec!["keys".to_string(), inline_value_summary(dict.keys())], + vec!["values".to_string(), inline_value_summary(dict.values())], + ]; + + let opts = active_options(); + let rows = preview_slots(all_rows.len(), opts.max_rows, opts.row_display) + .into_iter() + .map(|slot| match slot { + PreviewSlot::Index(i) => all_rows[i].clone(), + PreviewSlot::Ellipsis => vec!["...".to_string(), "...".to_string()], + }) + .collect(); + + render_preview( + vec![format!("Dictionary [dict, sorted={sorted}]")], + vec!["part".to_string(), "value".to_string()], + rows, + vec![format!("shape: ({size},)")], + ) +} + +// --------------------------------------------------------------------------- +// Table +// --------------------------------------------------------------------------- + +fn column_primitive_label(col: &CoreValue) -> &'static str { + match col { + CoreValue::Vector(v) => primitive_label(v.data()), + CoreValue::List(_) => "list", + CoreValue::Atom(_) => "atom", + _ => "?", + } +} + +fn table_cell(col: &CoreValue, row_index: usize) -> String { + match col { + CoreValue::Vector(v) => match v.data() { + VectorData::Char(chars) => { + // Show a single char per cell + if row_index < chars.len() { + (chars[row_index] as char).to_string() + } else { + "?".to_string() + } + } + data => format_vector_item(data, row_index), + }, + CoreValue::Atom(atom) => format_atom_cell(atom), + CoreValue::List(list) => { + if row_index < list.len() { + inline_value_summary(&list.values()[row_index]) + } else { + "?".to_string() + } + } + _ => 
inline_value_summary(col), + } +} + +fn column_name(raw: &[u8]) -> String { + String::from_utf8_lossy(raw).into_owned() +} + +pub fn format_table(table: &CoreTable) -> String { + let num_rows = table.len(); + let num_cols = table.num_columns(); + let opts = active_options(); + let visible_cols = num_cols.min(opts.max_columns); + + // Build headers: "name\ntype" for each visible column + let mut headers: Vec = table + .column_names() + .iter() + .zip(table.columns().iter()) + .take(visible_cols) + .map(|(name, col)| { + let col_name = truncate(column_name(name)); + let type_label = column_primitive_label(col); + format!("{col_name}\n{type_label}") + }) + .collect(); + + if num_cols > visible_cols { + headers.push("...\n...".to_string()); + } else if headers.is_empty() { + headers.push("value".to_string()); + } + + // Build rows + let row_slots = preview_slots(num_rows, opts.max_rows, opts.row_display); + let columns = table.columns(); + + let body_rows: Vec> = row_slots + .into_iter() + .map(|slot| { + let mut row: Vec = match slot { + PreviewSlot::Index(row_i) => (0..visible_cols) + .map(|col_i| table_cell(&columns[col_i], row_i)) + .collect(), + PreviewSlot::Ellipsis => vec!["...".to_string(); visible_cols.max(1)], + }; + if num_cols > visible_cols { + row.push("...".to_string()); + } + row + }) + .collect(); + + render_preview( + vec![format!( + "Table [table, attr={}]", + attribute_label(table.attribute()) + )], + headers, + body_rows, + vec![format!("shape: ({num_rows}, {num_cols})")], + ) +} + +// --------------------------------------------------------------------------- +// UnaryPrimitive +// --------------------------------------------------------------------------- + +#[allow(dead_code)] +pub fn format_unary_primitive(opcode: i8) -> String { + render_preview( + vec!["UnaryPrimitive [unary_primitive]".to_string()], + vec!["opcode".to_string()], + vec![vec![format!("0x{opcode:02x}")]], + vec!["shape: (1,)".to_string()], + ) +} diff --git 
a/crates/qroissant-python/src/repr/mod.rs b/crates/qroissant-python/src/repr/mod.rs new file mode 100644 index 0000000..3aac13b --- /dev/null +++ b/crates/qroissant-python/src/repr/mod.rs @@ -0,0 +1,26 @@ +//! Pretty repr system for qroissant Python values. +//! +//! This module provides: +//! - [`options`] — global `FormattingOptions`, `RowDisplay`, and associated +//! builder and pyfunctions (`get_formatting_options`, `set_formatting_options`, +//! `reset_formatting_options`). +//! - [`cell`] — individual q value → string conversion without Arrow. +//! - [`render`] — ASCII table rendering via `tabled` and `preview_slots`. +//! - [`format`] — shape-level formatting functions called by `__repr__`/`__str__`. + +pub mod cell; +pub mod format; +pub mod options; +pub mod render; + +pub use format::format_atom; +pub use format::format_dictionary; +pub use format::format_list; +pub use format::format_table; +pub use format::format_vector; +use pyo3::prelude::*; +use pyo3::types::PyModule; + +pub fn register(module: &Bound<'_, PyModule>) -> PyResult<()> { + options::register(module) +} diff --git a/crates/qroissant-python/src/repr/options.rs b/crates/qroissant-python/src/repr/options.rs new file mode 100644 index 0000000..6f3488c --- /dev/null +++ b/crates/qroissant-python/src/repr/options.rs @@ -0,0 +1,172 @@ +//! Global repr formatting options and associated Python types. + +use std::sync::OnceLock; +use std::sync::RwLock; + +use pyo3::prelude::*; +use pyo3::types::PyModule; + +/// Row selection strategy used by qroissant repr formatting. +#[derive(PartialEq, Eq, Default, Clone, Copy, Debug)] +#[pyclass( + eq, + eq_int, + frozen, + rename_all = "SCREAMING_SNAKE_CASE", + module = "qroissant" +)] +pub enum RowDisplay { + /// Show the first `max_rows` rows followed by an ellipsis when truncated. + #[default] + Head, + /// Show the first half and last half of rows with an ellipsis in the middle. 
+ HeadTail, +} + +#[pymethods] +impl RowDisplay { + fn __repr__(&self) -> &'static str { + match self { + Self::Head => "RowDisplay.HEAD", + Self::HeadTail => "RowDisplay.HEAD_TAIL", + } + } +} + +/// Formatting options for user-facing qroissant string representations. +/// +/// Notes +/// ----- +/// These options control how qroissant values render through `str(...)` and +/// `repr(...)`. Apply them process-wide through `set_formatting_options(...)`. +#[pyclass(get_all, eq, frozen, skip_from_py_object, module = "qroissant")] +#[derive(PartialEq, Eq, Clone, Debug)] +pub struct FormattingOptions { + pub max_rows: usize, + pub max_columns: usize, + pub row_display: RowDisplay, +} + +impl Default for FormattingOptions { + fn default() -> Self { + Self { + max_rows: 8, + max_columns: 6, + row_display: RowDisplay::Head, + } + } +} + +#[pymethods] +impl FormattingOptions { + #[staticmethod] + /// Create a builder initialized with qroissant's default formatting policy. + fn builder() -> FormattingOptionsBuilder { + FormattingOptionsBuilder::default() + } + + fn __repr__(&self) -> String { + format!( + "FormattingOptions(max_rows={}, max_columns={}, row_display={})", + self.max_rows, + self.max_columns, + self.row_display.__repr__(), + ) + } +} + +/// Builder for [`FormattingOptions`]. 
+#[pyclass(skip_from_py_object, module = "qroissant")] +#[derive(Default, Clone, Debug)] +pub struct FormattingOptionsBuilder { + inner: FormattingOptions, +} + +#[pymethods] +impl FormattingOptionsBuilder { + #[pyo3(signature = (value, /))] + fn with_max_rows(&self, value: usize) -> Self { + let mut b = self.clone(); + b.inner.max_rows = value; + b + } + + #[pyo3(signature = (value, /))] + fn with_max_columns(&self, value: usize) -> Self { + let mut b = self.clone(); + b.inner.max_columns = value; + b + } + + #[pyo3(signature = (value, /))] + fn with_row_display(&self, value: RowDisplay) -> Self { + let mut b = self.clone(); + b.inner.row_display = value; + b + } + + /// Finalize the builder into an immutable `FormattingOptions` instance. + fn build(&self) -> FormattingOptions { + self.inner.clone() + } + + fn __repr__(&self) -> String { + format!("FormattingOptionsBuilder({})", self.inner.__repr__()) + } +} + +// --------------------------------------------------------------------------- +// Global state +// --------------------------------------------------------------------------- + +fn options_lock() -> &'static RwLock { + static OPTIONS: OnceLock> = OnceLock::new(); + OPTIONS.get_or_init(|| RwLock::new(FormattingOptions::default())) +} + +pub fn active_options() -> FormattingOptions { + match options_lock().read() { + Ok(guard) => guard.clone(), + Err(poisoned) => poisoned.into_inner().clone(), + } +} + +fn store_options(options: FormattingOptions) { + match options_lock().write() { + Ok(mut guard) => *guard = options, + Err(poisoned) => *poisoned.into_inner() = options, + } +} + +// --------------------------------------------------------------------------- +// Python-visible functions +// --------------------------------------------------------------------------- + +#[pyfunction] +/// Return the active qroissant repr formatting options. 
+pub fn get_formatting_options() -> FormattingOptions {
+    active_options()
+}
+
+#[pyfunction]
+#[pyo3(signature = (options, /))]
+/// Update the active qroissant repr formatting options.
+pub fn set_formatting_options(options: PyRef<'_, FormattingOptions>) {
+    store_options(options.clone());
+}
+
+#[pyfunction]
+/// Restore qroissant's default repr formatting options.
+pub fn reset_formatting_options() {
+    store_options(FormattingOptions::default());
+}
+
+/// Register the formatting classes and functions on the Python module.
+pub fn register(module: &Bound<'_, PyModule>) -> PyResult<()> {
+    module.add_class::<RowDisplay>()?;
+    module.add_class::<FormattingOptions>()?;
+    module.add_class::<FormattingOptionsBuilder>()?;
+    module.add_function(wrap_pyfunction!(get_formatting_options, module)?)?;
+    module.add_function(wrap_pyfunction!(set_formatting_options, module)?)?;
+    module.add_function(wrap_pyfunction!(reset_formatting_options, module)?)?;
+    Ok(())
+}
diff --git a/crates/qroissant-python/src/repr/render.rs b/crates/qroissant-python/src/repr/render.rs
new file mode 100644
index 0000000..3f7feb8
--- /dev/null
+++ b/crates/qroissant-python/src/repr/render.rs
@@ -0,0 +1,80 @@
+//! ASCII table rendering via the `tabled` crate and row-slot utilities.
+
+use tabled::builder::Builder;
+use tabled::settings::Alignment;
+use tabled::settings::Modify;
+use tabled::settings::Style;
+use tabled::settings::object::Rows;
+use tabled::settings::style::HorizontalLine;
+
+use super::options::RowDisplay;
+
+/// A slot in a preview: either a concrete row index or an ellipsis separator.
+#[derive(Clone, Copy)]
+pub enum PreviewSlot {
+    Index(usize),
+    Ellipsis,
+}
+
+/// Compute the row slots to show when rendering at most `max_rows` out of
+/// `total`, using `row_display` to decide whether to use head or head+tail.
+pub fn preview_slots(total: usize, max_rows: usize, row_display: RowDisplay) -> Vec<PreviewSlot> {
+    // Nothing to show: either there are no rows or none may be displayed.
+    if total == 0 || max_rows == 0 {
+        return Vec::new();
+    }
+
+    // Everything fits, so no ellipsis is needed.
+    if total <= max_rows {
+        return (0..total).map(PreviewSlot::Index).collect();
+    }
+
+    match row_display {
+        RowDisplay::Head => {
+            let mut slots = (0..max_rows).map(PreviewSlot::Index).collect::<Vec<_>>();
+            slots.push(PreviewSlot::Ellipsis);
+            slots
+        }
+        RowDisplay::HeadTail => {
+            // An odd budget gives the extra row to the head. With
+            // `max_rows == 1` the tail is empty and the result is
+            // `[Index(0), Ellipsis]`, so no special case is required.
+            let head = max_rows.div_ceil(2);
+            let tail = max_rows / 2;
+            let mut slots = (0..head).map(PreviewSlot::Index).collect::<Vec<_>>();
+            slots.push(PreviewSlot::Ellipsis);
+            let tail_start = total.saturating_sub(tail);
+            slots.extend((tail_start..total).map(PreviewSlot::Index));
+            slots
+        }
+    }
+}
+
+/// Build an ASCII table with a modern style and a horizontal line after the
+/// header row.
+///
+/// `headers` becomes the first record and each entry of `rows` one further
+/// record; the header row is left-aligned.
+pub fn render_table(headers: Vec<String>, rows: Vec<Vec<String>>) -> String {
+    let mut builder = Builder::default();
+    builder.push_record(headers);
+    for row in rows {
+        builder.push_record(row);
+    }
+    let mut table = builder.build();
+    // Remove the per-row separators, then re-add a single one under the header.
+    table.with(
+        Style::modern()
+            .remove_horizontal()
+            .horizontals([(1, HorizontalLine::inherit(Style::modern()))]),
+    );
+    table.with(Modify::new(Rows::first()).with(Alignment::left()));
+    table.to_string()
+}
+
+/// Assemble a full repr block: optional title lines, a table, optional footer.
+pub fn render_preview(
+    title_lines: Vec<String>,
+    headers: Vec<String>,
+    rows: Vec<Vec<String>>,
+    footer_lines: Vec<String>,
+) -> String {
+    // Title lines, then the rendered table, then footer lines, one per line.
+    let mut sections = title_lines;
+    sections.push(render_table(headers, rows));
+    sections.extend(footer_lines);
+    sections.join("\n")
+}
diff --git a/crates/qroissant-python/src/serde.rs b/crates/qroissant-python/src/serde.rs
new file mode 100644
index 0000000..397faae
--- /dev/null
+++ b/crates/qroissant-python/src/serde.rs
@@ -0,0 +1,215 @@
+use std::sync::Arc;
+
+use pyo3::prelude::*;
+use pyo3::types::PyAny;
+use pyo3::types::PyBytes;
+use qroissant_arrow::ListProjection;
+use qroissant_arrow::ProjectionOptions;
+use qroissant_arrow::StringProjection;
+use qroissant_arrow::SymbolProjection;
+use qroissant_core::DecodeOptions as CoreDecodeOptions;
+use qroissant_core::Value as CoreValue;
+use qroissant_core::decode_message_with_options;
+use qroissant_core::encode_message;
+use qroissant_transport::extract_q_error;
+
+use crate::errors::PythonError;
+use crate::errors::PythonResult;
+use crate::errors::to_py_err;
+use crate::types::Compression;
+use crate::types::DecodeOptions;
+use crate::types::EncodeOptions;
+use crate::types::Encoding;
+use crate::types::ListInterpretation;
+use crate::types::MessageType;
+use crate::types::StringInterpretation;
+use crate::types::SymbolInterpretation;
+use crate::values::core_value_to_python_with_opts;
+use crate::values::python_to_core_value;
+
+/// Maps Python-facing "Interpretation" options to Rust-internal "Projection" options.
+///
+/// The Python API uses "Interpretation" (e.g. `SymbolInterpretation`) as it describes
+/// how the user wants data to be interpreted. The Rust/Arrow layer uses "Projection"
+/// (e.g. `SymbolProjection`) as it describes how values are projected into Arrow arrays.
+/// Both refer to the same concept viewed from different perspectives.
+pub fn decode_options_to_proj_opts(opts: Option<&DecodeOptions>) -> Arc<ProjectionOptions> {
+    // `None` means "use the defaults".
+    let opts = opts.cloned().unwrap_or_default();
+    Arc::new(ProjectionOptions {
+        symbol: match opts.symbol_interpretation_value() {
+            SymbolInterpretation::Utf8 => SymbolProjection::Utf8,
+            SymbolInterpretation::LargeUtf8 => SymbolProjection::LargeUtf8,
+            SymbolInterpretation::Utf8View => SymbolProjection::Utf8View,
+            SymbolInterpretation::Dictionary => SymbolProjection::Dictionary,
+            SymbolInterpretation::RawBytes => SymbolProjection::RawBytes,
+        },
+        string: match opts.string_interpretation_value() {
+            StringInterpretation::Utf8 => StringProjection::Utf8,
+            StringInterpretation::Binary => StringProjection::Binary,
+        },
+        list: match opts.list_interpretation_value() {
+            ListInterpretation::List => ListProjection::List,
+            ListInterpretation::LargeList => ListProjection::LargeList,
+            ListInterpretation::ListView => ListProjection::ListView,
+        },
+        union_mode: match opts.union_mode_value() {
+            crate::types::UnionMode::Dense => qroissant_arrow::UnionMode::Dense,
+            crate::types::UnionMode::Sparse => qroissant_arrow::UnionMode::Sparse,
+        },
+        treat_infinity_as_null: opts.treat_infinity_as_null(),
+        parallel: opts.parallel_value(),
+        assume_symbol_utf8: opts.assume_symbol_utf8_value(),
+    })
+}
+
+/// Build core decode options from the Python-facing ones.
+///
+/// Only `parallel` is forwarded; every other field keeps its core default.
+fn decode_options_to_core(opts: &DecodeOptions) -> CoreDecodeOptions {
+    CoreDecodeOptions {
+        parallel: opts.parallel_value(),
+        ..CoreDecodeOptions::default()
+    }
+}
+
+/// Reject any encode options that differ from the defaults.
+///
+/// Returns `PythonError::NotImplemented` for custom options; `None` and
+/// default options are accepted.
+fn ensure_default_encode_options(options: Option<&EncodeOptions>) -> PythonResult<()> {
+    if let Some(options) = options
+        && options != &EncodeOptions::default()
+    {
+        return Err(PythonError::NotImplemented(
+            "custom encode options are not implemented yet".to_string(),
+        ));
+    }
+    Ok(())
+}
+
+/// Decode a q IPC payload into a core value plus the projection options
+/// derived from `options`.
+///
+/// Fails with `PythonError::QRuntime` when `extract_q_error` finds a q error
+/// message in the payload, and with `PythonError::Decode` when decoding fails.
+pub fn decode_core_value(
+    payload: bytes::Bytes,
+    options: Option<&DecodeOptions>,
+) -> PythonResult<(CoreValue, Arc<ProjectionOptions>)> {
+    // Surface a q-side error before attempting a full decode.
+    if let Some(message) =
+        extract_q_error(payload.as_ref()).map_err(crate::errors::map_transport_error)?
+    {
+        return Err(PythonError::QRuntime(message));
+    }
+    let core_opts = options.map(decode_options_to_core).unwrap_or_default();
+    let decoded = decode_message_with_options(payload, &core_opts)
+        .map_err(|error| PythonError::Decode(error.to_string()))?;
+    let proj_opts = decode_options_to_proj_opts(options);
+    let (_header, value) = decoded.into_parts();
+    Ok((value, proj_opts))
+}
+
+/// Wraps a Python `bytes` object in a [`bytes::Bytes`] without copying.
+///
+/// CPython `bytes` objects are immutable and their backing memory is never
+/// moved, so it is sound to hold a raw pointer into them as long as the
+/// `Py<PyBytes>` reference (which increments the CPython refcount) is alive.
+struct PinnedPyBytes {
+    _owner: Py<PyBytes>,
+    ptr: *const u8,
+    len: usize,
+}
+
+// SAFETY: `Py<PyBytes>` is `Send`, and the pointed-to memory is immutable.
+unsafe impl Send for PinnedPyBytes {}
+// SAFETY: The data is immutable and the owner keeps it alive.
+unsafe impl Sync for PinnedPyBytes {}
+
+impl AsRef<[u8]> for PinnedPyBytes {
+    #[inline]
+    fn as_ref(&self) -> &[u8] {
+        // SAFETY: `ptr` is valid for `len` bytes while `_owner` keeps the
+        // CPython bytes object alive (refcount > 0, no deallocation possible).
+        unsafe { std::slice::from_raw_parts(self.ptr, self.len) }
+    }
+}
+
+/// Minimum payload size for the zero-copy `PinnedPyBytes` path.
+///
+/// For small payloads the `Arc` allocation inside `Bytes::from_owner` costs
+/// more than a plain `memcpy`, so we fall back to copying below this threshold.
+const ZERO_COPY_MIN_BYTES: usize = 32 * 1024; // 32 KB
+
+/// Converts a Python `bytes`-like object into a [`bytes::Bytes`].
+///
+/// For plain `bytes` objects ≥ [`ZERO_COPY_MIN_BYTES`] the underlying buffer
+/// is **borrowed without copying** via [`bytes::Bytes::from_owner`].
+/// Smaller payloads and other buffer protocols (bytearray, memoryview) take a
+/// single copy — same cost as before.
+fn payload_to_bytes(payload: &Bound<'_, PyAny>) -> PyResult<bytes::Bytes> {
+    if let Ok(pb) = payload.downcast::<PyBytes>() {
+        let data = pb.as_bytes();
+        if data.len() >= ZERO_COPY_MIN_BYTES {
+            // Keep a strong reference to the Python object so the pointer
+            // stays valid for the lifetime of the returned `Bytes`.
+            let pinned = PinnedPyBytes {
+                _owner: pb.clone().unbind(),
+                ptr: data.as_ptr(),
+                len: data.len(),
+            };
+            return Ok(bytes::Bytes::from_owner(pinned));
+        }
+        return Ok(bytes::Bytes::copy_from_slice(data));
+    }
+    // Not a plain `bytes` object: extract into an owned buffer.
+    Ok(bytes::Bytes::from(payload.extract::<Vec<u8>>()?))
+}
+
+/// Encode a core value into a q IPC message.
+///
+/// Non-default `options` are rejected with `PythonError::NotImplemented`;
+/// encoder failures are reported as `PythonError::Protocol`.
+pub fn encode_core_value_bytes(
+    value: &CoreValue,
+    options: Option<&EncodeOptions>,
+    encoding: Encoding,
+    message_type: MessageType,
+    compression: Compression,
+) -> PythonResult<Vec<u8>> {
+    ensure_default_encode_options(options)?;
+    encode_message(
+        value,
+        encoding.into(),
+        message_type.into(),
+        compression.into(),
+    )
+    .map_err(|error| PythonError::Protocol(error.to_string()))
+}
+
+/// Decode a q IPC payload into a Python value.
+///
+/// Decoding runs inside `py.detach`; the conversion to Python objects happens
+/// afterwards.
+// NOTE(review): the return type's generic arguments are not visible in the
+// source; `Py<PyAny>` is assumed to match `core_value_to_python_with_opts`.
+#[pyfunction]
+#[pyo3(signature = (payload, /, *, options=None))]
+pub fn decode(
+    py: Python<'_>,
+    payload: &Bound<'_, PyAny>,
+    options: Option<&DecodeOptions>,
+) -> PyResult<Py<PyAny>> {
+    let bytes = payload_to_bytes(payload)?;
+    let options_clone = options.cloned();
+    let (value, proj_opts) = py
+        .detach(|| decode_core_value(bytes, options_clone.as_ref()))
+        .map_err(to_py_err)?;
+    core_value_to_python_with_opts(py, value, proj_opts)
+}
+
+/// Encode a Python value into a q IPC message returned as Python `bytes`.
+///
+/// The value is converted to a core value first; encoding then runs inside
+/// `py.detach`.
+#[pyfunction]
+#[pyo3(signature = (value, /, *, options=None, encoding=Encoding::LittleEndian, message_type=MessageType::Asynchronous, compression=Compression::Uncompressed))]
+pub fn encode(
+    py: Python<'_>,
+    value: &Bound<'_, PyAny>,
+    options: Option<&EncodeOptions>,
+    encoding: Encoding,
+    message_type: MessageType,
+    compression: Compression,
+) -> PyResult<Py<PyBytes>> {
+    let value = python_to_core_value(value)?;
+    let options_clone = options.cloned();
+    let payload = py
+        .detach(|| {
+            encode_core_value_bytes(
+                &value,
+                options_clone.as_ref(),
+                encoding,
+                message_type,
+                compression,
+            )
+        })
+        .map_err(to_py_err)?;
+    Ok(PyBytes::new(py, &payload).unbind())
+}
+
+/// Register `decode` and `encode` on the Python module.
+pub fn register(module: &Bound<'_, PyModule>)
+-> PyResult<()> {
+    module.add_function(wrap_pyfunction!(decode, module)?)?;
+    module.add_function(wrap_pyfunction!(encode, module)?)?;
+    Ok(())
+}
diff --git a/crates/qroissant-python/src/types.rs b/crates/qroissant-python/src/types.rs
new file mode 100644
index 0000000..4f2eabd
--- /dev/null
+++ b/crates/qroissant-python/src/types.rs
@@ -0,0 +1,1325 @@
+use std::collections::BTreeMap;
+
+use pyo3::prelude::*;
+use qroissant_core::ValueType;
+
+/// Python-facing mirror of `qroissant_core::Attribute`.
+#[pyclass(
+    module = "qroissant",
+    eq,
+    eq_int,
+    frozen,
+    rename_all = "SCREAMING_SNAKE_CASE"
+)]
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
+pub enum Attribute {
+    None,
+    Sorted,
+    Unique,
+    Parted,
+    Grouped,
+}
+
+#[pymethods]
+impl Attribute {}
+
+impl From<qroissant_core::Attribute> for Attribute {
+    fn from(value: qroissant_core::Attribute) -> Self {
+        match value {
+            qroissant_core::Attribute::None => Self::None,
+            qroissant_core::Attribute::Sorted => Self::Sorted,
+            qroissant_core::Attribute::Unique => Self::Unique,
+            qroissant_core::Attribute::Parted => Self::Parted,
+            qroissant_core::Attribute::Grouped => Self::Grouped,
+        }
+    }
+}
+
+impl From<Attribute> for qroissant_core::Attribute {
+    fn from(value: Attribute) -> Self {
+        match value {
+            Attribute::None => Self::None,
+            Attribute::Sorted => Self::Sorted,
+            Attribute::Unique => Self::Unique,
+            Attribute::Parted => Self::Parted,
+            Attribute::Grouped => Self::Grouped,
+        }
+    }
+}
+
+/// Python-facing mirror of `qroissant_core::Shape`.
+#[pyclass(
+    module = "qroissant",
+    eq,
+    eq_int,
+    frozen,
+    rename_all = "SCREAMING_SNAKE_CASE"
+)]
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
+pub enum Shape {
+    Atom,
+    Vector,
+    List,
+    Dictionary,
+    Table,
+    UnaryPrimitive,
+    Error,
+}
+
+#[pymethods]
+impl Shape {}
+
+impl From<qroissant_core::Shape> for Shape {
+    fn from(value: qroissant_core::Shape) -> Self {
+        match value {
+            qroissant_core::Shape::Atom => Self::Atom,
+            qroissant_core::Shape::Vector => Self::Vector,
+            qroissant_core::Shape::List => Self::List,
+            qroissant_core::Shape::Dictionary => Self::Dictionary,
+            qroissant_core::Shape::Table => Self::Table,
+            qroissant_core::Shape::UnaryPrimitive => Self::UnaryPrimitive,
+            qroissant_core::Shape::Error => Self::Error,
+        }
+    }
+}
+
+impl From<Shape> for qroissant_core::Shape {
+    fn from(value: Shape) -> Self {
+        match value {
+            Shape::Atom => Self::Atom,
+            Shape::Vector => Self::Vector,
+            Shape::List => Self::List,
+            Shape::Dictionary => Self::Dictionary,
+            Shape::Table => Self::Table,
+            Shape::UnaryPrimitive => Self::UnaryPrimitive,
+            Shape::Error => Self::Error,
+        }
+    }
+}
+
+/// Python-facing mirror of `qroissant_core::Primitive`.
+#[pyclass(
+    module = "qroissant",
+    eq,
+    eq_int,
+    frozen,
+    rename_all = "SCREAMING_SNAKE_CASE"
+)]
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
+pub enum Primitive {
+    Boolean,
+    Guid,
+    Byte,
+    Short,
+    Int,
+    Long,
+    Real,
+    Float,
+    Char,
+    Symbol,
+    Timestamp,
+    Month,
+    Date,
+    Datetime,
+    Timespan,
+    Minute,
+    Second,
+    Time,
+    Mixed,
+}
+
+#[pymethods]
+impl Primitive {}
+
+impl From<qroissant_core::Primitive> for Primitive {
+    fn from(value: qroissant_core::Primitive) -> Self {
+        match value {
+            qroissant_core::Primitive::Boolean => Self::Boolean,
+            qroissant_core::Primitive::Guid => Self::Guid,
+            qroissant_core::Primitive::Byte => Self::Byte,
+            qroissant_core::Primitive::Short => Self::Short,
+            qroissant_core::Primitive::Int => Self::Int,
+            qroissant_core::Primitive::Long => Self::Long,
+            qroissant_core::Primitive::Real => Self::Real,
+            qroissant_core::Primitive::Float => Self::Float,
+            qroissant_core::Primitive::Char => Self::Char,
+            qroissant_core::Primitive::Symbol => Self::Symbol,
+            qroissant_core::Primitive::Timestamp => Self::Timestamp,
+            qroissant_core::Primitive::Month => Self::Month,
+            qroissant_core::Primitive::Date => Self::Date,
+            qroissant_core::Primitive::Datetime => Self::Datetime,
+            qroissant_core::Primitive::Timespan => Self::Timespan,
+            qroissant_core::Primitive::Minute => Self::Minute,
+            qroissant_core::Primitive::Second => Self::Second,
+            qroissant_core::Primitive::Time => Self::Time,
+            qroissant_core::Primitive::Mixed => Self::Mixed,
+        }
+    }
+}
+
+impl From<Primitive> for qroissant_core::Primitive {
+    fn from(value: Primitive) -> Self {
+        match value {
+            Primitive::Boolean => Self::Boolean,
+            Primitive::Guid => Self::Guid,
+            Primitive::Byte => Self::Byte,
+            Primitive::Short => Self::Short,
+            Primitive::Int => Self::Int,
+            Primitive::Long => Self::Long,
+            Primitive::Real => Self::Real,
+            Primitive::Float => Self::Float,
+            Primitive::Char => Self::Char,
+            Primitive::Symbol => Self::Symbol,
+            Primitive::Timestamp => Self::Timestamp,
+            Primitive::Month => Self::Month,
+            Primitive::Date => Self::Date,
+            Primitive::Datetime => Self::Datetime,
+            Primitive::Timespan => Self::Timespan,
+            Primitive::Minute => Self::Minute,
+            Primitive::Second => Self::Second,
+            Primitive::Time => Self::Time,
+            Primitive::Mixed => Self::Mixed,
+        }
+    }
+}
+
+/// Python-facing mirror of `qroissant_core::Encoding`.
+#[pyclass(
+    module = "qroissant",
+    eq,
+    eq_int,
+    frozen,
+    rename_all = "SCREAMING_SNAKE_CASE"
+)]
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
+pub enum Encoding {
+    LittleEndian,
+    BigEndian,
+}
+
+#[pymethods]
+impl Encoding {}
+
+impl From<qroissant_core::Encoding> for Encoding {
+    fn from(value: qroissant_core::Encoding) -> Self {
+        match value {
+            qroissant_core::Encoding::LittleEndian => Self::LittleEndian,
+            qroissant_core::Encoding::BigEndian => Self::BigEndian,
+        }
+    }
+}
+
+impl From<Encoding> for qroissant_core::Encoding {
+    fn from(value: Encoding) -> Self {
+        match value {
+            Encoding::LittleEndian => Self::LittleEndian,
+            Encoding::BigEndian => Self::BigEndian,
+        }
+    }
+}
+
+/// Python-facing mirror of `qroissant_core::Compression`.
+#[pyclass(
+    module = "qroissant",
+    eq,
+    eq_int,
+    frozen,
+    rename_all = "SCREAMING_SNAKE_CASE"
+)]
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
+pub enum Compression {
+    Uncompressed,
+    Compressed,
+    CompressedLarge,
+}
+
+#[pymethods]
+impl Compression {}
+
+impl From<qroissant_core::Compression> for Compression {
+    fn from(value: qroissant_core::Compression) -> Self {
+        match value {
+            qroissant_core::Compression::Uncompressed => Self::Uncompressed,
+            qroissant_core::Compression::Compressed => Self::Compressed,
+            qroissant_core::Compression::CompressedLarge => Self::CompressedLarge,
+        }
+    }
+}
+
+impl From<Compression> for qroissant_core::Compression {
+    fn from(value: Compression) -> Self {
+        match value {
+            Compression::Uncompressed => Self::Uncompressed,
+            Compression::Compressed => Self::Compressed,
+            Compression::CompressedLarge => Self::CompressedLarge,
+        }
+    }
+}
+
+/// Python-facing mirror of `qroissant_core::MessageType`.
+#[pyclass(
+    module = "qroissant",
+    eq,
+    eq_int,
+    frozen,
+    rename_all = "SCREAMING_SNAKE_CASE"
+)]
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
+pub enum MessageType {
+    Asynchronous,
+    Synchronous,
+    Response,
+}
+
+#[pymethods]
+impl MessageType {}
+
+impl From<qroissant_core::MessageType> for MessageType {
+    fn from(value: qroissant_core::MessageType) -> Self {
+        match value {
+            qroissant_core::MessageType::Asynchronous => Self::Asynchronous,
+            qroissant_core::MessageType::Synchronous => Self::Synchronous,
+            qroissant_core::MessageType::Response => Self::Response,
+        }
+    }
+}
+
+impl From<MessageType> for qroissant_core::MessageType {
+    fn from(value: MessageType) -> Self {
+        match value {
+            MessageType::Asynchronous => Self::Asynchronous,
+            MessageType::Synchronous => Self::Synchronous,
+            MessageType::Response => Self::Response,
+        }
+    }
+}
+
+/// How q symbols are interpreted when decoding; defaults to `Utf8`.
+#[pyclass(
+    module = "qroissant",
+    eq,
+    eq_int,
+    frozen,
+    rename_all = "SCREAMING_SNAKE_CASE"
+)]
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Default)]
+pub enum SymbolInterpretation {
+    #[default]
+    Utf8,
+    LargeUtf8,
+    Utf8View,
+    Dictionary,
+    RawBytes,
+}
+
+#[pymethods]
+impl SymbolInterpretation {}
+
+/// How q lists are interpreted when decoding; defaults to `List`.
+#[pyclass(
+    module = "qroissant",
+    eq,
+    eq_int,
+    frozen,
+    rename_all = "SCREAMING_SNAKE_CASE"
+)]
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Default)]
+pub enum ListInterpretation {
+    #[default]
+    List,
+    LargeList,
+    ListView,
+}
+
+#[pymethods]
+impl ListInterpretation {}
+
+/// How q strings are interpreted when decoding; defaults to `Utf8`.
+#[pyclass(
+    module = "qroissant",
+    eq,
+    eq_int,
+    frozen,
+    rename_all = "SCREAMING_SNAKE_CASE"
+)]
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Default)]
+pub enum StringInterpretation {
+    #[default]
+    Utf8,
+    Binary,
+}
+
+#[pymethods]
+impl StringInterpretation {}
+
+#[pyclass( + module = "qroissant", + eq, + eq_int, + frozen, + rename_all = "SCREAMING_SNAKE_CASE" +)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Default)] +pub enum UnionMode { + #[default] + Dense, + Sparse, +} + +#[pymethods] +impl UnionMode {} + +#[pyclass(module = "qroissant", frozen, eq)] +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Type { + primitive: Option, + shape: Shape, + attribute: Option, + sorted: Option, +} + +#[pymethods] +impl Type { + #[new] + #[pyo3(signature = (primitive, shape, attribute=None, sorted=None))] + fn new( + primitive: Option, + shape: Shape, + attribute: Option, + sorted: Option, + ) -> Self { + Self { + primitive, + shape, + attribute, + sorted, + } + } + + #[getter] + fn primitive(&self) -> Option { + self.primitive + } + + #[getter] + fn shape(&self) -> Shape { + self.shape + } + + #[getter] + fn attribute(&self) -> Option { + self.attribute + } + + #[getter] + fn sorted(&self) -> Option { + self.sorted + } +} + +impl From for Type { + fn from(value: ValueType) -> Self { + Self { + primitive: value.primitive.map(Primitive::from), + shape: Shape::from(value.shape), + attribute: value.attribute.map(Attribute::from), + sorted: value.sorted, + } + } +} + +impl Type { + pub fn to_core(&self) -> ValueType { + ValueType { + primitive: self.primitive.map(qroissant_core::Primitive::from), + shape: qroissant_core::Shape::from(self.shape), + attribute: self.attribute.map(qroissant_core::Attribute::from), + sorted: self.sorted, + } + } + + pub fn primitive_value(&self) -> Option { + self.primitive + } + + pub fn shape_value(&self) -> Shape { + self.shape + } + + pub fn attribute_value(&self) -> Option { + self.attribute + } + + pub fn sorted_value(&self) -> Option { + self.sorted + } +} + +#[pyclass(module = "qroissant", frozen, eq)] +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct MessageHeader { + encoding: Encoding, + message_type: MessageType, + compression: Compression, + size: usize, +} + +#[pymethods] +impl 
MessageHeader { + #[new] + fn new( + encoding: Encoding, + message_type: MessageType, + compression: Compression, + size: usize, + ) -> Self { + Self { + encoding, + message_type, + compression, + size, + } + } + + #[getter] + fn encoding(&self) -> Encoding { + self.encoding + } + + #[getter] + fn message_type(&self) -> MessageType { + self.message_type + } + + #[getter] + fn compression(&self) -> Compression { + self.compression + } + + #[getter] + fn size(&self) -> usize { + self.size + } +} + +impl From for MessageHeader { + fn from(value: qroissant_core::MessageHeader) -> Self { + Self { + encoding: Encoding::from(value.encoding()), + message_type: MessageType::from(value.message_type()), + compression: Compression::from(value.compression()), + size: value.size(), + } + } +} + +impl MessageHeader { + pub fn new_native( + encoding: Encoding, + message_type: MessageType, + compression: Compression, + size: usize, + ) -> Self { + Self { + encoding, + message_type, + compression, + size, + } + } +} + +#[pyclass(module = "qroissant", frozen, eq)] +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Endpoint { + scheme: String, + host: Option, + port: Option, + path: Option, + username: Option, + password: Option, + timeout_ms: Option, +} + +#[pymethods] +impl Endpoint { + #[staticmethod] + #[pyo3(signature = (host, port, *, username=None, password=None, timeout_ms=None))] + fn tcp( + host: String, + port: u16, + username: Option, + password: Option, + timeout_ms: Option, + ) -> Self { + Self { + scheme: "tcp".to_string(), + host: Some(host), + port: Some(port), + path: None, + username, + password, + timeout_ms, + } + } + + #[staticmethod] + #[pyo3(signature = (path, *, username=None, password=None, timeout_ms=None))] + fn unix( + path: String, + username: Option, + password: Option, + timeout_ms: Option, + ) -> Self { + Self { + scheme: "unix".to_string(), + host: None, + port: None, + path: Some(path), + username, + password, + timeout_ms, + } + } + + #[getter] + 
fn scheme(&self) -> String { + self.scheme.clone() + } + + #[getter] + fn host(&self) -> Option { + self.host.clone() + } + + #[getter] + fn port(&self) -> Option { + self.port + } + + #[getter] + fn path(&self) -> Option { + self.path.clone() + } + + #[getter] + fn username(&self) -> Option { + self.username.clone() + } + + #[getter] + fn password(&self) -> Option { + self.password.clone() + } + + #[getter] + fn timeout_ms(&self) -> Option { + self.timeout_ms + } +} + +impl Endpoint { + pub fn validate(&self) -> Result<(), String> { + match self.scheme.as_str() { + "tcp" => { + if self.host.is_none() { + return Err("tcp endpoints require a host".to_string()); + } + if self.port.is_none() { + return Err("tcp endpoints require a port".to_string()); + } + } + "unix" => { + if self.path.is_none() { + return Err("unix endpoints require a path".to_string()); + } + } + _ => { + return Err(format!("unsupported endpoint scheme {:?}", self.scheme)); + } + } + Ok(()) + } + + pub fn username_deref(&self) -> Option<&str> { + self.username.as_deref() + } + + pub fn password_deref(&self) -> Option<&str> { + self.password.as_deref() + } + + pub fn scheme_value(&self) -> &str { + &self.scheme + } + + pub fn host_value(&self) -> Option<&str> { + self.host.as_deref() + } + + pub fn port_value(&self) -> Option { + self.port + } + + pub fn path_value(&self) -> Option<&str> { + self.path.as_deref() + } + + pub fn timeout_ms_value(&self) -> Option { + self.timeout_ms + } +} + +#[pyclass(module = "qroissant", frozen, eq)] +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct PoolOptions { + max_size: u32, + min_idle: Option, + checkout_timeout_ms: u64, + idle_timeout_ms: Option, + max_lifetime_ms: Option, + test_on_checkout: bool, + healthcheck_query: Option, + retry_attempts: u32, + retry_backoff_ms: u64, +} + +#[pymethods] +impl PoolOptions { + #[new] + #[pyo3(signature = (*, max_size=10, min_idle=None, checkout_timeout_ms=30_000, idle_timeout_ms=None, max_lifetime_ms=None, 
test_on_checkout=true, healthcheck_query=Some("::".to_string()), retry_attempts=0, retry_backoff_ms=0))] + #[allow(clippy::too_many_arguments)] + fn new( + max_size: u32, + min_idle: Option, + checkout_timeout_ms: u64, + idle_timeout_ms: Option, + max_lifetime_ms: Option, + test_on_checkout: bool, + healthcheck_query: Option, + retry_attempts: u32, + retry_backoff_ms: u64, + ) -> Self { + Self { + max_size, + min_idle, + checkout_timeout_ms, + idle_timeout_ms, + max_lifetime_ms, + test_on_checkout, + healthcheck_query, + retry_attempts, + retry_backoff_ms, + } + } + + #[getter] + fn max_size(&self) -> u32 { + self.max_size + } + + #[getter] + fn min_idle(&self) -> Option { + self.min_idle + } + + #[getter] + fn checkout_timeout_ms(&self) -> u64 { + self.checkout_timeout_ms + } + + #[getter] + fn idle_timeout_ms(&self) -> Option { + self.idle_timeout_ms + } + + #[getter] + fn max_lifetime_ms(&self) -> Option { + self.max_lifetime_ms + } + + #[getter] + fn test_on_checkout(&self) -> bool { + self.test_on_checkout + } + + #[getter] + fn healthcheck_query(&self) -> Option { + self.healthcheck_query.clone() + } + + #[getter] + fn retry_attempts(&self) -> u32 { + self.retry_attempts + } + + #[getter] + fn retry_backoff_ms(&self) -> u64 { + self.retry_backoff_ms + } +} + +impl Default for PoolOptions { + fn default() -> Self { + Self { + max_size: 10, + min_idle: None, + checkout_timeout_ms: 30_000, + idle_timeout_ms: None, + max_lifetime_ms: None, + test_on_checkout: true, + healthcheck_query: Some("::".to_string()), + retry_attempts: 0, + retry_backoff_ms: 0, + } + } +} + +impl PoolOptions { + pub fn max_size_value(&self) -> u32 { + self.max_size + } + + pub fn min_idle_value(&self) -> Option { + self.min_idle + } + + pub fn checkout_timeout_ms_value(&self) -> u64 { + self.checkout_timeout_ms + } + + pub fn idle_timeout_ms_value(&self) -> Option { + self.idle_timeout_ms + } + + pub fn max_lifetime_ms_value(&self) -> Option { + self.max_lifetime_ms + } + + pub fn 
test_on_checkout_value(&self) -> bool { + self.test_on_checkout + } + + pub fn healthcheck_query_value(&self) -> Option<&str> { + self.healthcheck_query.as_deref() + } + + pub fn retry_attempts_value(&self) -> u32 { + self.retry_attempts + } + + pub fn retry_backoff_ms_value(&self) -> u64 { + self.retry_backoff_ms + } +} + +#[pyclass(module = "qroissant", frozen, eq)] +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct DecodeOptions { + list_interpretation: ListInterpretation, + union_mode: UnionMode, + string_interpretation: StringInterpretation, + symbol_interpretation: SymbolInterpretation, + assume_symbol_utf8: bool, + parallel: bool, + preserve_original_body: bool, + validate_compressed_trailing_bytes: bool, + temporal_nulls: bool, + treat_infinity_as_null: bool, +} + +#[pymethods] +impl DecodeOptions { + #[staticmethod] + fn builder() -> DecodeOptionsBuilder { + DecodeOptionsBuilder::default() + } + + #[getter] + fn list_interpretation(&self) -> ListInterpretation { + self.list_interpretation + } + + #[getter] + fn union_mode(&self) -> UnionMode { + self.union_mode + } + + #[getter] + fn string_interpretation(&self) -> StringInterpretation { + self.string_interpretation + } + + #[getter] + fn symbol_interpretation(&self) -> SymbolInterpretation { + self.symbol_interpretation + } + + #[getter] + fn assume_symbol_utf8(&self) -> bool { + self.assume_symbol_utf8 + } + + #[getter] + fn parallel(&self) -> bool { + self.parallel + } + + #[getter] + fn preserve_original_body(&self) -> bool { + self.preserve_original_body + } + + #[getter] + fn validate_compressed_trailing_bytes(&self) -> bool { + self.validate_compressed_trailing_bytes + } + + #[getter] + fn temporal_nulls(&self) -> bool { + self.temporal_nulls + } + + #[getter] + fn get_treat_infinity_as_null(&self) -> bool { + self.treat_infinity_as_null + } +} + +impl Default for DecodeOptions { + fn default() -> Self { + Self { + list_interpretation: ListInterpretation::List, + union_mode: UnionMode::Dense, + 
string_interpretation: StringInterpretation::Utf8, + symbol_interpretation: SymbolInterpretation::Utf8, + assume_symbol_utf8: true, + parallel: true, + preserve_original_body: true, + validate_compressed_trailing_bytes: true, + temporal_nulls: true, + treat_infinity_as_null: false, + } + } +} + +impl DecodeOptions { + pub fn is_default(&self) -> bool { + self == &Self::default() + } + + pub(crate) fn list_interpretation_value(&self) -> ListInterpretation { + self.list_interpretation + } + + pub(crate) fn string_interpretation_value(&self) -> StringInterpretation { + self.string_interpretation + } + + pub(crate) fn symbol_interpretation_value(&self) -> SymbolInterpretation { + self.symbol_interpretation + } + + pub(crate) fn union_mode_value(&self) -> UnionMode { + self.union_mode + } + + pub(crate) fn treat_infinity_as_null(&self) -> bool { + self.treat_infinity_as_null + } + + pub(crate) fn parallel_value(&self) -> bool { + self.parallel + } + + pub(crate) fn assume_symbol_utf8_value(&self) -> bool { + self.assume_symbol_utf8 + } +} + +#[pyclass(module = "qroissant", frozen, eq)] +#[derive(Clone, Debug, PartialEq, Eq, Default)] +pub struct DecodeOptionsBuilder { + options: DecodeOptions, +} + +#[pymethods] +impl DecodeOptionsBuilder { + fn with_list_interpretation(&self, value: ListInterpretation) -> Self { + let mut next = self.clone(); + next.options.list_interpretation = value; + next + } + + fn with_union_mode(&self, value: UnionMode) -> Self { + let mut next = self.clone(); + next.options.union_mode = value; + next + } + + fn with_string_interpretation(&self, value: StringInterpretation) -> Self { + let mut next = self.clone(); + next.options.string_interpretation = value; + next + } + + fn with_symbol_interpretation(&self, value: SymbolInterpretation) -> Self { + let mut next = self.clone(); + next.options.symbol_interpretation = value; + next + } + + fn with_assume_symbol_utf8(&self, value: bool) -> Self { + let mut next = self.clone(); + 
next.options.assume_symbol_utf8 = value; + next + } + + fn with_parallel(&self, value: bool) -> Self { + let mut next = self.clone(); + next.options.parallel = value; + next + } + + fn with_preserve_original_body(&self, value: bool) -> Self { + let mut next = self.clone(); + next.options.preserve_original_body = value; + next + } + + fn with_validate_compressed_trailing_bytes(&self, value: bool) -> Self { + let mut next = self.clone(); + next.options.validate_compressed_trailing_bytes = value; + next + } + + fn with_temporal_nulls(&self, value: bool) -> Self { + let mut next = self.clone(); + next.options.temporal_nulls = value; + next + } + + fn with_treat_infinity_as_null(&self, value: bool) -> Self { + let mut next = self.clone(); + next.options.treat_infinity_as_null = value; + next + } + + fn build(&self) -> DecodeOptions { + self.options.clone() + } + + #[getter] + fn list_interpretation(&self) -> ListInterpretation { + self.options.list_interpretation + } + + #[getter] + fn union_mode(&self) -> UnionMode { + self.options.union_mode + } + + #[getter] + fn string_interpretation(&self) -> StringInterpretation { + self.options.string_interpretation + } + + #[getter] + fn symbol_interpretation(&self) -> SymbolInterpretation { + self.options.symbol_interpretation + } + + #[getter] + fn assume_symbol_utf8(&self) -> bool { + self.options.assume_symbol_utf8 + } + + #[getter] + fn parallel(&self) -> bool { + self.options.parallel + } + + #[getter] + fn preserve_original_body(&self) -> bool { + self.options.preserve_original_body + } + + #[getter] + fn validate_compressed_trailing_bytes(&self) -> bool { + self.options.validate_compressed_trailing_bytes + } + + #[getter] + fn temporal_nulls(&self) -> bool { + self.options.temporal_nulls + } + + #[getter] + fn treat_infinity_as_null(&self) -> bool { + self.options.treat_infinity_as_null + } +} + +#[pyclass(module = "qroissant", frozen, eq)] +#[derive(Clone, Debug, PartialEq, Eq, Default)] +pub struct EncodeField { + 
primitive: Option, + shape: Option, + attribute: Option, + sorted: Option, +} + +#[pymethods] +impl EncodeField { + #[staticmethod] + fn builder() -> EncodeFieldBuilder { + EncodeFieldBuilder::default() + } + + #[getter] + fn primitive(&self) -> Option { + self.primitive + } + + #[getter] + fn shape(&self) -> Option { + self.shape + } + + #[getter] + fn attribute(&self) -> Option { + self.attribute + } + + #[getter] + fn sorted(&self) -> Option { + self.sorted + } +} + +#[pyclass(module = "qroissant", frozen, eq)] +#[derive(Clone, Debug, PartialEq, Eq, Default)] +pub struct EncodeFieldBuilder { + field: EncodeField, +} + +#[pymethods] +impl EncodeFieldBuilder { + fn with_primitive(&self, value: Primitive) -> Self { + let mut next = self.clone(); + next.field.primitive = Some(value); + next + } + + fn with_shape(&self, value: Shape) -> Self { + let mut next = self.clone(); + next.field.shape = Some(value); + next + } + + fn with_attribute(&self, value: Attribute) -> Self { + let mut next = self.clone(); + next.field.attribute = Some(value); + next + } + + fn with_sorted(&self, value: bool) -> Self { + let mut next = self.clone(); + next.field.sorted = Some(value); + next + } + + fn build(&self) -> EncodeField { + self.field.clone() + } +} + +#[pyclass(module = "qroissant", frozen, eq)] +#[derive(Clone, Debug, PartialEq, Eq, Default)] +pub struct EncodeOptions { + primitive: Option, + shape: Option, + attribute: Option, + strict: bool, + fields: BTreeMap, +} + +#[pymethods] +impl EncodeOptions { + #[staticmethod] + fn builder() -> EncodeOptionsBuilder { + EncodeOptionsBuilder::default() + } + + #[getter] + fn primitive(&self) -> Option { + self.primitive + } + + #[getter] + fn shape(&self) -> Option { + self.shape + } + + #[getter] + fn attribute(&self) -> Option { + self.attribute + } + + #[getter] + fn strict(&self) -> bool { + self.strict + } + + fn field(&self, name: &str) -> Option { + self.fields.get(name).cloned() + } +} + +#[pyclass(module = "qroissant", 
frozen, eq)] +#[derive(Clone, Debug, PartialEq, Eq, Default)] +pub struct EncodeOptionsBuilder { + options: EncodeOptions, +} + +#[pymethods] +impl EncodeOptionsBuilder { + fn with_primitive(&self, value: Primitive) -> Self { + let mut next = self.clone(); + next.options.primitive = Some(value); + next + } + + fn with_shape(&self, value: Shape) -> Self { + let mut next = self.clone(); + next.options.shape = Some(value); + next + } + + fn with_attribute(&self, value: Attribute) -> Self { + let mut next = self.clone(); + next.options.attribute = Some(value); + next + } + + fn with_strict(&self, value: bool) -> Self { + let mut next = self.clone(); + next.options.strict = value; + next + } + + fn with_field(&self, name: String, field: EncodeField) -> Self { + let mut next = self.clone(); + next.options.fields.insert(name, field); + next + } + + fn build(&self) -> EncodeOptions { + self.options.clone() + } +} + +impl EncodeOptions { + pub fn is_default(&self) -> bool { + self == &Self::default() + } +} + +#[pyclass(module = "qroissant", frozen, eq)] +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct PoolMetrics { + connections: u32, + idle_connections: u32, + max_size: u32, + min_idle: Option, + initialized: bool, + closed: bool, +} + +impl PoolMetrics { + pub fn new_native( + connections: u32, + idle_connections: u32, + max_size: u32, + min_idle: Option, + initialized: bool, + closed: bool, + ) -> Self { + Self { + connections, + idle_connections, + max_size, + min_idle, + initialized, + closed, + } + } +} + +#[pymethods] +impl PoolMetrics { + #[new] + fn new( + connections: u32, + idle_connections: u32, + max_size: u32, + min_idle: Option, + initialized: bool, + closed: bool, + ) -> Self { + Self { + connections, + idle_connections, + max_size, + min_idle, + initialized, + closed, + } + } + + #[getter] + fn connections(&self) -> u32 { + self.connections + } + + #[getter] + fn idle_connections(&self) -> u32 { + self.idle_connections + } + + #[getter] + fn 
max_size(&self) -> u32 { + self.max_size + } + + #[getter] + fn min_idle(&self) -> Option { + self.min_idle + } + + #[getter] + fn initialized(&self) -> bool { + self.initialized + } + + #[getter] + fn closed(&self) -> bool { + self.closed + } +} + +pub fn register(module: &Bound<'_, PyModule>) -> PyResult<()> { + module.add_class::()?; + module.add_class::()?; + module.add_class::()?; + module.add_class::()?; + module.add_class::()?; + module.add_class::()?; + module.add_class::()?; + module.add_class::()?; + module.add_class::()?; + module.add_class::()?; + module.add_class::()?; + module.add_class::()?; + module.add_class::()?; + module.add_class::()?; + module.add_class::()?; + module.add_class::()?; + module.add_class::()?; + module.add_class::()?; + module.add_class::()?; + module.add_class::()?; + module.add_class::()?; + Ok(()) +} diff --git a/crates/qroissant-python/src/values.rs b/crates/qroissant-python/src/values.rs new file mode 100644 index 0000000..21123b4 --- /dev/null +++ b/crates/qroissant-python/src/values.rs @@ -0,0 +1,925 @@ +use std::sync::Arc; + +use pyo3::exceptions::PyIndexError; +use pyo3::exceptions::PyKeyError; +use pyo3::exceptions::PyNotImplementedError; +use pyo3::exceptions::PyValueError; +use pyo3::prelude::*; +use pyo3::types::PyAny; +use pyo3::types::PyBytes; +use pyo3::types::PyCapsule; +use pyo3::types::PyDict; +use pyo3::types::PyIterator; +use pyo3::types::PyList; +use pyo3::types::PyTuple; +use pyo3_arrow::ffi::ArrayIterator; +use pyo3_arrow::ffi::to_array_pycapsules; +use pyo3_arrow::ffi::to_stream_pycapsule; +use qroissant_arrow::IngestionError; +use qroissant_arrow::ProjectionOptions; +use qroissant_arrow::ingest_array; +use qroissant_arrow::ingest_record_batch; +use qroissant_arrow::ingest_record_batch_reader; +use qroissant_arrow::project; +use qroissant_arrow::project_table; +use qroissant_core::Atom as CoreAtom; +use qroissant_core::Dictionary as CoreDictionary; +use qroissant_core::List as CoreList; +use 
qroissant_core::Table as CoreTable; +use qroissant_core::Value as CoreValue; +use qroissant_core::Vector as CoreVector; +use qroissant_core::VectorData; + +use crate::errors::to_py_err; +use crate::types::Attribute; +use crate::types::Compression; +use crate::types::Encoding; +use crate::types::MessageType; +use crate::types::Primitive; +use crate::types::Shape; +use crate::types::Type; + +#[pyclass(subclass, module = "qroissant")] +#[derive(Clone, Debug)] +pub struct Value { + inner: CoreValue, + projection_opts: Arc, +} + +impl Value { + pub fn new(inner: CoreValue) -> Self { + Self { + inner, + projection_opts: Arc::new(ProjectionOptions::default()), + } + } + + pub fn new_with_opts(inner: CoreValue, opts: Arc) -> Self { + Self { + inner, + projection_opts: opts, + } + } + + pub fn inner(&self) -> &CoreValue { + &self.inner + } + + pub fn into_inner(self) -> CoreValue { + self.inner + } + + pub fn projection_opts(&self) -> &Arc { + &self.projection_opts + } +} + +#[pymethods] +impl Value { + #[getter] + fn qtype(&self) -> Type { + Type::from(self.inner.qtype()) + } + + #[getter] + fn primitive(&self) -> Option { + self.inner.qtype().primitive.map(Primitive::from) + } + + #[getter] + fn shape(&self) -> Shape { + Shape::from(self.inner.qtype().shape) + } + + #[getter] + fn attribute(&self) -> Option { + self.inner.qtype().attribute.map(Attribute::from) + } + + #[pyo3(signature = (*, options=None, encoding=Encoding::LittleEndian, message_type=MessageType::Asynchronous, compression=Compression::Uncompressed))] + fn serialize( + &self, + options: Option<&crate::types::EncodeOptions>, + encoding: Encoding, + message_type: MessageType, + compression: Compression, + ) -> PyResult> { + let inner = self.inner.clone(); + let options_clone = options.cloned(); + Python::attach(|py| { + let payload = py + .detach(|| { + crate::serde::encode_core_value_bytes( + &inner, + options_clone.as_ref(), + encoding, + message_type, + compression, + ) + }) + .map_err(to_py_err)?; + 
Ok(PyBytes::new(py, &payload).unbind()) + }) + } +} + +#[pyclass(extends = Value, module = "qroissant")] +#[derive(Clone, Debug)] +pub struct Atom; + +#[pymethods] +impl Atom { + #[new] + fn new(qtype: PyRef<'_, Type>, value: &Bound<'_, PyAny>) -> PyResult<(Self, Value)> { + let core = atom_from_python(&qtype, value)?; + Ok((Self, Value::new(CoreValue::Atom(core)))) + } + + fn as_py(slf: PyRef<'_, Self>, py: Python<'_>) -> PyResult> { + match slf.as_super().inner() { + CoreValue::Atom(atom) => atom_to_python(py, atom), + _ => unreachable!("Atom instances always hold q atoms"), + } + } + + #[getter] + fn value(slf: PyRef<'_, Self>, py: Python<'_>) -> PyResult> { + Self::as_py(slf, py) + } + + fn is_null(slf: PyRef<'_, Self>) -> bool { + use qroissant_kernels::nulls::*; + match slf.as_super().inner() { + CoreValue::Atom(atom) => match atom { + CoreAtom::Boolean(_) + | CoreAtom::Guid(_) + | CoreAtom::Byte(_) + | CoreAtom::Char(_) + | CoreAtom::Symbol(_) => false, + CoreAtom::Short(v) => *v == Q_NULL_SHORT, + CoreAtom::Int(v) => *v == Q_NULL_INT, + CoreAtom::Long(v) => *v == Q_NULL_LONG, + CoreAtom::Real(v) => v.is_nan(), + CoreAtom::Float(v) => v.is_nan(), + CoreAtom::Timestamp(v) => *v == Q_NULL_TIMESTAMP, + CoreAtom::Month(v) => *v == Q_NULL_MONTH, + CoreAtom::Date(v) => *v == Q_NULL_DATE, + CoreAtom::Datetime(v) => v.is_nan(), + CoreAtom::Timespan(v) => *v == Q_NULL_TIMESPAN, + CoreAtom::Minute(v) => *v == Q_NULL_MINUTE, + CoreAtom::Second(v) => *v == Q_NULL_SECOND, + CoreAtom::Time(v) => *v == Q_NULL_TIME, + }, + _ => unreachable!("Atom instances always hold q atoms"), + } + } + + fn is_infinite(slf: PyRef<'_, Self>) -> bool { + use qroissant_kernels::nulls::*; + match slf.as_super().inner() { + CoreValue::Atom(atom) => match atom { + CoreAtom::Boolean(_) + | CoreAtom::Guid(_) + | CoreAtom::Byte(_) + | CoreAtom::Char(_) + | CoreAtom::Symbol(_) => false, + CoreAtom::Short(v) => *v == Q_INF_SHORT || *v == Q_NINF_SHORT, + CoreAtom::Int(v) => *v == Q_INF_INT || *v == 
Q_NINF_INT, + CoreAtom::Long(v) => *v == Q_INF_LONG || *v == Q_NINF_LONG, + CoreAtom::Real(v) => v.is_infinite(), + CoreAtom::Float(v) => v.is_infinite(), + CoreAtom::Timestamp(v) => *v == Q_INF_TIMESTAMP || *v == Q_NINF_TIMESTAMP, + CoreAtom::Month(v) => *v == Q_INF_MONTH || *v == Q_NINF_MONTH, + CoreAtom::Date(v) => *v == Q_INF_DATE || *v == Q_NINF_DATE, + CoreAtom::Datetime(v) => v.is_infinite(), + CoreAtom::Timespan(v) => *v == Q_INF_TIMESPAN || *v == Q_NINF_TIMESPAN, + CoreAtom::Minute(v) => *v == Q_INF_MINUTE || *v == Q_NINF_MINUTE, + CoreAtom::Second(v) => *v == Q_INF_SECOND || *v == Q_NINF_SECOND, + CoreAtom::Time(v) => *v == Q_INF_TIME || *v == Q_NINF_TIME, + }, + _ => unreachable!("Atom instances always hold q atoms"), + } + } + + #[pyo3(signature = (requested_schema=None))] + fn __arrow_c_array__( + slf: PyRef<'_, Self>, + py: Python<'_>, + requested_schema: Option>, + ) -> PyResult> { + let schema_capsule: Option> = requested_schema + .map(|s| s.downcast_into::()) + .transpose()?; + let opts = slf.as_super().projection_opts().clone(); + let export = project(slf.as_super().inner(), &opts) + .map_err(|e| PyNotImplementedError::new_err(e.to_string()))?; + let capsules = + to_array_pycapsules(py, export.field, export.array.as_ref(), schema_capsule)?; + Ok(capsules.unbind()) + } + + fn __repr__(slf: PyRef<'_, Self>) -> String { + match slf.as_super().inner() { + CoreValue::Atom(atom) => crate::repr::format_atom(atom), + _ => unreachable!("Atom instances always hold q atoms"), + } + } + + fn __str__(slf: PyRef<'_, Self>) -> String { + Self::__repr__(slf) + } +} + +#[pyclass(extends = Value, module = "qroissant")] +#[derive(Clone, Debug)] +pub struct Vector; + +#[pymethods] +impl Vector { + #[new] + fn new(qtype: PyRef<'_, Type>, values: Option<&Bound<'_, PyAny>>) -> PyResult<(Self, Value)> { + let core = if let Some(values) = values { + vector_from_python(&qtype, values)? 
+ } else { + let empty = PyList::empty(qtype.py()); + vector_from_python(&qtype, empty.as_any())? + }; + Ok((Self, Value::new(CoreValue::Vector(core)))) + } + + fn __len__(slf: PyRef<'_, Self>) -> usize { + match slf.as_super().inner() { + CoreValue::Vector(vector) => vector.len(), + _ => unreachable!("Vector instances always hold q vectors"), + } + } + + fn __iter__(slf: PyRef<'_, Self>, py: Python<'_>) -> PyResult> { + let list = Self::to_list(slf, py)?; + let iter = PyIterator::from_object(list.bind(py).as_any())?; + Ok(iter.into_any().unbind()) + } + + fn __getitem__(slf: PyRef<'_, Self>, py: Python<'_>, index: isize) -> PyResult> { + let vector = match slf.as_super().inner() { + CoreValue::Vector(vector) => vector, + _ => unreachable!("Vector instances always hold q vectors"), + }; + let index = normalize_index(index, vector.len())?; + vector_item_to_python(py, vector, index) + } + + fn to_list(slf: PyRef<'_, Self>, py: Python<'_>) -> PyResult> { + let vector = match slf.as_super().inner() { + CoreValue::Vector(vector) => vector, + _ => unreachable!("Vector instances always hold q vectors"), + }; + vector_to_pylist(py, vector) + } + + #[pyo3(signature = (requested_schema=None))] + fn __arrow_c_array__( + slf: PyRef<'_, Self>, + py: Python<'_>, + requested_schema: Option>, + ) -> PyResult> { + let schema_capsule: Option> = requested_schema + .map(|s| s.downcast_into::()) + .transpose()?; + let opts = slf.as_super().projection_opts().clone(); + let export = project(slf.as_super().inner(), &opts) + .map_err(|e| PyNotImplementedError::new_err(e.to_string()))?; + let capsules = + to_array_pycapsules(py, export.field, export.array.as_ref(), schema_capsule)?; + Ok(capsules.unbind()) + } + + fn __repr__(slf: PyRef<'_, Self>) -> String { + match slf.as_super().inner() { + CoreValue::Vector(vector) => crate::repr::format_vector(vector), + _ => unreachable!("Vector instances always hold q vectors"), + } + } + + fn __str__(slf: PyRef<'_, Self>) -> String { + 
Self::__repr__(slf) + } +} + +#[pyclass(extends = Value, module = "qroissant")] +#[derive(Clone, Debug)] +pub struct List; + +#[pymethods] +impl List { + #[new] + fn new(qtype: PyRef<'_, Type>, values: Option<&Bound<'_, PyAny>>) -> PyResult<(Self, Value)> { + let core = if let Some(values) = values { + list_from_python(&qtype, values)? + } else { + let empty = PyList::empty(qtype.py()); + list_from_python(&qtype, empty.as_any())? + }; + Ok((Self, Value::new(CoreValue::List(core)))) + } + + fn __len__(slf: PyRef<'_, Self>) -> usize { + match slf.as_super().inner() { + CoreValue::List(list) => list.len(), + _ => unreachable!("List instances always hold q lists"), + } + } + + fn __iter__(slf: PyRef<'_, Self>, py: Python<'_>) -> PyResult> { + let list = Self::to_list(slf, py)?; + let iter = PyIterator::from_object(list.bind(py).as_any())?; + Ok(iter.into_any().unbind()) + } + + fn __getitem__(slf: PyRef<'_, Self>, py: Python<'_>, index: isize) -> PyResult> { + let list = match slf.as_super().inner() { + CoreValue::List(list) => list, + _ => unreachable!("List instances always hold q lists"), + }; + let index = normalize_index(index, list.len())?; + core_value_to_python(py, list.values()[index].clone()) + } + + fn to_list(slf: PyRef<'_, Self>, py: Python<'_>) -> PyResult> { + let list = match slf.as_super().inner() { + CoreValue::List(list) => list, + _ => unreachable!("List instances always hold q lists"), + }; + let mut values = Vec::with_capacity(list.len()); + for value in list.values() { + values.push(core_value_to_python(py, value.clone())?); + } + Ok(PyList::new(py, values)?.unbind()) + } + + #[pyo3(signature = (requested_schema=None))] + fn __arrow_c_array__( + slf: PyRef<'_, Self>, + py: Python<'_>, + requested_schema: Option>, + ) -> PyResult> { + let schema_capsule: Option> = requested_schema + .map(|s| s.downcast_into::()) + .transpose()?; + let opts = slf.as_super().projection_opts().clone(); + let export = project(slf.as_super().inner(), &opts) + 
.map_err(|e| PyNotImplementedError::new_err(e.to_string()))?; + let capsules = + to_array_pycapsules(py, export.field, export.array.as_ref(), schema_capsule)?; + Ok(capsules.unbind()) + } + + fn __repr__(slf: PyRef<'_, Self>) -> String { + match slf.as_super().inner() { + CoreValue::List(list) => crate::repr::format_list(list), + _ => unreachable!("List instances always hold q lists"), + } + } + + fn __str__(slf: PyRef<'_, Self>) -> String { + Self::__repr__(slf) + } +} + +#[pyclass(extends = Value, module = "qroissant")] +#[derive(Clone, Debug)] +pub struct Dictionary; + +#[pymethods] +impl Dictionary { + #[new] + fn new( + qtype: PyRef<'_, Type>, + keys: &Bound<'_, PyAny>, + values: &Bound<'_, PyAny>, + ) -> PyResult<(Self, Value)> { + let core = dictionary_from_python(&qtype, keys, values)?; + Ok((Self, Value::new(CoreValue::Dictionary(core)))) + } + + #[getter] + fn keys(slf: PyRef<'_, Self>, py: Python<'_>) -> PyResult> { + match slf.as_super().inner() { + CoreValue::Dictionary(dictionary) => { + core_value_to_python(py, dictionary.keys().clone()) + } + _ => unreachable!("Dictionary instances always hold q dictionaries"), + } + } + + #[getter] + fn values(slf: PyRef<'_, Self>, py: Python<'_>) -> PyResult> { + match slf.as_super().inner() { + CoreValue::Dictionary(dictionary) => { + core_value_to_python(py, dictionary.values().clone()) + } + _ => unreachable!("Dictionary instances always hold q dictionaries"), + } + } + + fn __len__(slf: PyRef<'_, Self>) -> usize { + match slf.as_super().inner() { + CoreValue::Dictionary(dictionary) => dictionary.len(), + _ => unreachable!("Dictionary instances always hold q dictionaries"), + } + } + + #[pyo3(signature = (requested_schema=None))] + fn __arrow_c_array__( + slf: PyRef<'_, Self>, + py: Python<'_>, + requested_schema: Option>, + ) -> PyResult> { + let schema_capsule: Option> = requested_schema + .map(|s| s.downcast_into::()) + .transpose()?; + let opts = slf.as_super().projection_opts().clone(); + let export = 
project(slf.as_super().inner(), &opts) + .map_err(|e| PyNotImplementedError::new_err(e.to_string()))?; + let capsules = + to_array_pycapsules(py, export.field, export.array.as_ref(), schema_capsule)?; + Ok(capsules.unbind()) + } + + fn __repr__(slf: PyRef<'_, Self>) -> String { + match slf.as_super().inner() { + CoreValue::Dictionary(dict) => crate::repr::format_dictionary(dict), + _ => unreachable!("Dictionary instances always hold q dictionaries"), + } + } + + fn __str__(slf: PyRef<'_, Self>) -> String { + Self::__repr__(slf) + } +} + +#[pyclass(extends = Value, module = "qroissant")] +#[derive(Clone, Debug)] +pub struct Table; + +#[pymethods] +impl Table { + #[new] + fn new(qtype: PyRef<'_, Type>, columns: Option<&Bound<'_, PyAny>>) -> PyResult<(Self, Value)> { + let core = if let Some(columns) = columns { + table_from_python(&qtype, columns)? + } else { + let empty = PyDict::new(qtype.py()); + table_from_python(&qtype, empty.as_any())? + }; + Ok((Self, Value::new(CoreValue::Table(core)))) + } + + #[getter] + fn columns(slf: PyRef<'_, Self>) -> PyResult> { + match slf.as_super().inner() { + CoreValue::Table(table) => table + .column_names() + .iter() + .map(|name| { + String::from_utf8(name.to_vec()).map_err(|_| { + PyValueError::new_err("q table column names must be valid UTF-8 for now") + }) + }) + .collect(), + _ => unreachable!("Table instances always hold q tables"), + } + } + + #[getter] + fn num_rows(slf: PyRef<'_, Self>) -> usize { + match slf.as_super().inner() { + CoreValue::Table(table) => table.len(), + _ => unreachable!("Table instances always hold q tables"), + } + } + + #[getter] + fn num_columns(slf: PyRef<'_, Self>) -> usize { + match slf.as_super().inner() { + CoreValue::Table(table) => table.num_columns(), + _ => unreachable!("Table instances always hold q tables"), + } + } + + fn column(slf: PyRef<'_, Self>, py: Python<'_>, name: &str) -> PyResult> { + match slf.as_super().inner() { + CoreValue::Table(table) => { + let needle = 
name.as_bytes(); + for (idx, candidate) in table.column_names().iter().enumerate() { + if candidate.as_ref() == needle { + return core_value_to_python(py, table.columns()[idx].clone()); + } + } + Err(PyKeyError::new_err(name.to_string())) + } + _ => unreachable!("Table instances always hold q tables"), + } + } + + #[pyo3(signature = (requested_schema=None))] + fn __arrow_c_stream__( + slf: PyRef<'_, Self>, + py: Python<'_>, + requested_schema: Option>, + ) -> PyResult> { + let schema_capsule: Option> = requested_schema + .map(|s| s.downcast_into::()) + .transpose()?; + let table = match slf.as_super().inner() { + qroissant_core::Value::Table(t) => t.clone(), + _ => unreachable!("Table instances always hold q tables"), + }; + let opts = slf.as_super().projection_opts().clone(); + let export = py + .detach(|| project_table(&table, &opts).map_err(|e| e.to_string())) + .map_err(|e| PyNotImplementedError::new_err(e))?; + let reader = ArrayIterator::new(vec![Ok(export.struct_array)], export.struct_field); + let capsule = to_stream_pycapsule(py, Box::new(reader), schema_capsule)?; + Ok(capsule.into_any().unbind()) + } + + fn __repr__(slf: PyRef<'_, Self>) -> String { + match slf.as_super().inner() { + CoreValue::Table(table) => crate::repr::format_table(table), + _ => unreachable!("Table instances always hold q tables"), + } + } + + fn __str__(slf: PyRef<'_, Self>) -> String { + Self::__repr__(slf) + } +} + +fn normalize_index(index: isize, len: usize) -> PyResult { + let len = len as isize; + let index = if index < 0 { len + index } else { index }; + if !(0..len).contains(&index) { + return Err(PyIndexError::new_err("index out of range")); + } + Ok(index as usize) +} + +fn bytes_or_utf8(value: &Bound<'_, PyAny>) -> PyResult> { + if let Ok(bytes) = value.extract::>() { + return Ok(bytes); + } + Ok(value.extract::()?.into_bytes()) +} + +fn atom_to_python(py: Python<'_>, atom: &CoreAtom) -> PyResult> { + match atom { + CoreAtom::Boolean(value) => 
Ok(value.into_pyobject(py)?.to_owned().unbind().into_any()), + CoreAtom::Guid(value) => Ok(PyBytes::new(py, value).unbind().into_any()), + CoreAtom::Byte(value) => Ok(value.into_pyobject(py)?.unbind().into_any()), + CoreAtom::Short(value) => Ok(value.into_pyobject(py)?.unbind().into_any()), + CoreAtom::Int(value) => Ok(value.into_pyobject(py)?.unbind().into_any()), + CoreAtom::Long(value) => Ok(value.into_pyobject(py)?.unbind().into_any()), + CoreAtom::Real(value) => Ok(value.into_pyobject(py)?.unbind().into_any()), + CoreAtom::Float(value) => Ok(value.into_pyobject(py)?.unbind().into_any()), + CoreAtom::Char(value) => Ok(value.into_pyobject(py)?.unbind().into_any()), + CoreAtom::Symbol(value) => Ok(PyBytes::new(py, value).unbind().into_any()), + CoreAtom::Timestamp(value) => Ok(value.into_pyobject(py)?.unbind().into_any()), + CoreAtom::Month(value) => Ok(value.into_pyobject(py)?.unbind().into_any()), + CoreAtom::Date(value) => Ok(value.into_pyobject(py)?.unbind().into_any()), + CoreAtom::Datetime(value) => Ok(value.into_pyobject(py)?.unbind().into_any()), + CoreAtom::Timespan(value) => Ok(value.into_pyobject(py)?.unbind().into_any()), + CoreAtom::Minute(value) => Ok(value.into_pyobject(py)?.unbind().into_any()), + CoreAtom::Second(value) => Ok(value.into_pyobject(py)?.unbind().into_any()), + CoreAtom::Time(value) => Ok(value.into_pyobject(py)?.unbind().into_any()), + } +} + +fn atom_from_python(qtype: &Type, value: &Bound<'_, PyAny>) -> PyResult { + ensure_shape(qtype, Shape::Atom)?; + let primitive = qtype + .primitive_value() + .ok_or_else(|| PyValueError::new_err("atom qtype requires a primitive"))?; + match primitive { + Primitive::Boolean => Ok(CoreAtom::Boolean(value.extract()?)), + Primitive::Guid => { + let bytes = value.extract::>()?; + let guid: [u8; 16] = bytes.try_into().map_err(|_| { + PyValueError::new_err("guid atoms must be backed by exactly 16 bytes") + })?; + Ok(CoreAtom::Guid(guid)) + } + Primitive::Byte => Ok(CoreAtom::Byte(value.extract()?)), 
+ Primitive::Short => Ok(CoreAtom::Short(value.extract()?)), + Primitive::Int => Ok(CoreAtom::Int(value.extract()?)), + Primitive::Long => Ok(CoreAtom::Long(value.extract()?)), + Primitive::Real => Ok(CoreAtom::Real(value.extract()?)), + Primitive::Float => Ok(CoreAtom::Float(value.extract()?)), + Primitive::Char => Ok(CoreAtom::Char(extract_char_like(value)?)), + Primitive::Symbol => Ok(CoreAtom::Symbol(bytes::Bytes::from(bytes_or_utf8(value)?))), + Primitive::Timestamp => Ok(CoreAtom::Timestamp(value.extract()?)), + Primitive::Month => Ok(CoreAtom::Month(value.extract()?)), + Primitive::Date => Ok(CoreAtom::Date(value.extract()?)), + Primitive::Datetime => Ok(CoreAtom::Datetime(value.extract()?)), + Primitive::Timespan => Ok(CoreAtom::Timespan(value.extract()?)), + Primitive::Minute => Ok(CoreAtom::Minute(value.extract()?)), + Primitive::Second => Ok(CoreAtom::Second(value.extract()?)), + Primitive::Time => Ok(CoreAtom::Time(value.extract()?)), + Primitive::Mixed => Err(PyValueError::new_err("mixed atoms are not valid q values")), + } +} + +fn extract_char_like(value: &Bound<'_, PyAny>) -> PyResult { + if let Ok(byte) = value.extract::() { + return Ok(byte); + } + let bytes = value.extract::>()?; + let [byte] = <[u8; 1]>::try_from(bytes.as_slice()) + .map_err(|_| PyValueError::new_err("char values must be a single byte or integer"))?; + Ok(byte) +} + +fn vector_from_python(qtype: &Type, values: &Bound<'_, PyAny>) -> PyResult { + ensure_shape(qtype, Shape::Vector)?; + let primitive = qtype + .primitive_value() + .ok_or_else(|| PyValueError::new_err("vector qtype requires a primitive"))?; + let list = values + .cast::() + .map_err(|_| PyValueError::new_err("vector payloads must be Python lists"))?; + let attribute = qtype.attribute_value().unwrap_or(Attribute::None).into(); + let data = match primitive { + Primitive::Boolean => { + let bools: Vec = extract_list(list, |item| item.extract())?; + let bytes: Vec = bools.into_iter().map(|b| if b { 1 } else { 0 
}).collect(); + VectorData::Boolean(bytes::Bytes::from(bytes)) + } + Primitive::Guid => VectorData::from_guids(&extract_list(list, |item| { + let bytes = item.extract::>()?; + bytes + .try_into() + .map_err(|_| PyValueError::new_err("guid vector elements must be exactly 16 bytes")) + })?), + Primitive::Byte => VectorData::Byte(bytes::Bytes::from(extract_list(list, |item| { + item.extract::() + })?)), + Primitive::Short => VectorData::from_i16s(&extract_list(list, |item| item.extract())?), + Primitive::Int => VectorData::from_i32s(&extract_list(list, |item| item.extract())?), + Primitive::Long => VectorData::from_i64s(&extract_list(list, |item| item.extract())?), + Primitive::Real => VectorData::from_f32s(&extract_list(list, |item| item.extract())?), + Primitive::Float => VectorData::from_f64s(&extract_list(list, |item| item.extract())?), + Primitive::Char => { + VectorData::Char(bytes::Bytes::from(extract_list(list, extract_char_like)?)) + } + Primitive::Symbol => VectorData::Symbol( + extract_list(list, bytes_or_utf8)? + .into_iter() + .map(bytes::Bytes::from) + .collect(), + ), + Primitive::Timestamp => { + VectorData::from_timestamps(&extract_list(list, |item| item.extract())?) + } + Primitive::Month => VectorData::from_months(&extract_list(list, |item| item.extract())?), + Primitive::Date => VectorData::from_dates(&extract_list(list, |item| item.extract())?), + Primitive::Datetime => { + VectorData::from_datetimes(&extract_list(list, |item| item.extract())?) + } + Primitive::Timespan => { + VectorData::from_timespans(&extract_list(list, |item| item.extract())?) 
+ } + Primitive::Minute => VectorData::from_minutes(&extract_list(list, |item| item.extract())?), + Primitive::Second => VectorData::from_seconds(&extract_list(list, |item| item.extract())?), + Primitive::Time => VectorData::from_times(&extract_list(list, |item| item.extract())?), + Primitive::Mixed => { + return Err(PyValueError::new_err( + "mixed vectors must use List rather than Vector", + )); + } + }; + Ok(CoreVector::new(attribute, data)) +} + +fn list_from_python(qtype: &Type, values: &Bound<'_, PyAny>) -> PyResult { + ensure_shape(qtype, Shape::List)?; + let list = values + .cast::() + .map_err(|_| PyValueError::new_err("list payloads must be Python lists"))?; + let attribute = qtype.attribute_value().unwrap_or(Attribute::None).into(); + let mut inner = Vec::with_capacity(list.len()); + for item in list.iter() { + inner.push(python_to_core_value(&item)?); + } + Ok(CoreList::new(attribute, inner)) +} + +fn dictionary_from_python( + qtype: &Type, + keys: &Bound<'_, PyAny>, + values: &Bound<'_, PyAny>, +) -> PyResult { + ensure_shape(qtype, Shape::Dictionary)?; + let sorted = qtype.sorted_value().unwrap_or(false); + let dictionary = CoreDictionary::new( + sorted, + python_to_core_value(keys)?, + python_to_core_value(values)?, + ); + dictionary + .validate() + .map_err(|error| PyValueError::new_err(error.to_string()))?; + Ok(dictionary) +} + +fn table_from_python(qtype: &Type, columns: &Bound<'_, PyAny>) -> PyResult { + ensure_shape(qtype, Shape::Table)?; + let columns = columns + .cast::() + .map_err(|_| PyValueError::new_err("table payloads must be Python dicts"))?; + let attribute = qtype.attribute_value().unwrap_or(Attribute::None).into(); + let mut names = Vec::with_capacity(columns.len()); + let mut values = Vec::with_capacity(columns.len()); + for (name, column) in columns.iter() { + names.push(bytes::Bytes::from(name.extract::()?.into_bytes())); + values.push(python_to_core_value(&column)?); + } + let table = CoreTable::new(attribute, names, values); + 
table + .validate() + .map_err(|error| PyValueError::new_err(error.to_string()))?; + Ok(table) +} + +fn ensure_shape(qtype: &Type, expected: Shape) -> PyResult<()> { + if qtype.shape_value() != expected { + return Err(PyValueError::new_err(format!( + "qtype shape {:?} does not match {:?}", + qtype.shape_value(), + expected + ))); + } + Ok(()) +} + +fn extract_list(items: &Bound<'_, PyList>, convert: F) -> PyResult> +where + F: Fn(&Bound<'_, PyAny>) -> PyResult, +{ + let mut values = Vec::with_capacity(items.len()); + for item in items.iter() { + values.push(convert(&item)?); + } + Ok(values) +} + +fn vector_to_pylist(py: Python<'_>, vector: &CoreVector) -> PyResult> { + let len = vector.len(); + let mut values = Vec::with_capacity(len); + for index in 0..len { + values.push(vector_item_to_python(py, vector, index)?); + } + Ok(PyList::new(py, values)?.unbind()) +} + +fn vector_item_to_python(py: Python<'_>, vector: &CoreVector, index: usize) -> PyResult> { + let data = vector.data(); + match data { + VectorData::Boolean(values) => Ok((values[index] != 0) + .into_pyobject(py)? + .to_owned() + .unbind() + .into_any()), + VectorData::Guid(values) => { + let chunk = &values[index * 16..(index + 1) * 16]; + Ok(PyBytes::new(py, chunk).unbind().into_any()) + } + VectorData::Byte(values) => Ok(values[index].into_pyobject(py)?.unbind().into_any()), + VectorData::Short(_) => Ok(data.as_i16_slice()[index] + .into_pyobject(py)? + .unbind() + .into_any()), + VectorData::Int(_) + | VectorData::Month(_) + | VectorData::Date(_) + | VectorData::Minute(_) + | VectorData::Second(_) + | VectorData::Time(_) => Ok(data.as_i32_slice()[index] + .into_pyobject(py)? + .unbind() + .into_any()), + VectorData::Long(_) | VectorData::Timestamp(_) | VectorData::Timespan(_) => Ok(data + .as_i64_slice()[index] + .into_pyobject(py)? + .unbind() + .into_any()), + VectorData::Real(_) => Ok(data.as_f32_slice()[index] + .into_pyobject(py)? 
+ .unbind() + .into_any()), + VectorData::Float(_) | VectorData::Datetime(_) => Ok(data.as_f64_slice()[index] + .into_pyobject(py)? + .unbind() + .into_any()), + VectorData::Char(values) => Ok(values[index].into_pyobject(py)?.unbind().into_any()), + VectorData::Symbol(values) => Ok(PyBytes::new(py, &values[index]).unbind().into_any()), + } +} + +fn map_ingestion_error(e: IngestionError) -> PyErr { + PyValueError::new_err(e.to_string()) +} + +pub fn python_to_core_value(value: &Bound<'_, PyAny>) -> PyResult { + // Try qroissant Value first (it also implements Arrow protocols, so must come first). + if let Ok(q_value) = value.extract::>() { + return Ok(q_value.inner().clone()); + } + + // Check Arrow stream protocol (record batches → table). + if value.hasattr("__arrow_c_stream__")? { + let capsule_obj = value.getattr("__arrow_c_stream__")?.call0()?; + let stream_capsule = capsule_obj.downcast::().map_err(PyErr::from)?; + let reader = + pyo3_arrow::PyRecordBatchReader::from_arrow_pycapsule(stream_capsule)?.into_reader()?; + let schema = reader.schema(); + let value = ingest_record_batch_reader(schema, reader).map_err(map_ingestion_error)?; + return Ok(value); + } + + // Check Arrow array protocol (single array or record batch). + if value.hasattr("__arrow_c_array__")? { + // Try extracting as a record batch first. + if let Ok(record_batch) = value.extract::() { + let batch = record_batch.into_inner(); + let value = ingest_record_batch(batch).map_err(map_ingestion_error)?; + return Ok(value); + } + // Fall back to plain array. 
+ let array: pyo3_arrow::PyArray = value.extract()?; + let (array, field) = array.into_inner(); + let value = ingest_array(array, field.as_ref()).map_err(map_ingestion_error)?; + return Ok(value); + } + + Err(PyNotImplementedError::new_err( + "encoding non-qroissant values is not implemented yet; \ + pass a qroissant Value or an object implementing the Arrow protocol", + )) +} + +pub fn core_value_to_python(py: Python<'_>, value: CoreValue) -> PyResult> { + core_value_to_python_with_opts(py, value, Arc::new(ProjectionOptions::default())) +} + +pub fn core_value_to_python_with_opts( + py: Python<'_>, + value: CoreValue, + opts: Arc, +) -> PyResult> { + match value { + CoreValue::Atom(atom) => Ok(Py::new( + py, + (Atom, Value::new_with_opts(CoreValue::Atom(atom), opts)), + )? + .into_any()), + CoreValue::Vector(vector) => Ok(Py::new( + py, + ( + Vector, + Value::new_with_opts(CoreValue::Vector(vector), opts), + ), + )? + .into_any()), + CoreValue::List(list) => Ok(Py::new( + py, + (List, Value::new_with_opts(CoreValue::List(list), opts)), + )? + .into_any()), + CoreValue::Dictionary(dictionary) => Ok(Py::new( + py, + ( + Dictionary, + Value::new_with_opts(CoreValue::Dictionary(dictionary), opts), + ), + )? + .into_any()), + CoreValue::Table(table) => Ok(Py::new( + py, + (Table, Value::new_with_opts(CoreValue::Table(table), opts)), + )? 
+ .into_any()), + CoreValue::UnaryPrimitive { opcode } => { + Ok(Py::new(py, Value::new(CoreValue::UnaryPrimitive { opcode }))?.into_any()) + } + } +} + +pub fn register(module: &Bound<'_, PyModule>) -> PyResult<()> { + module.add_class::()?; + module.add_class::()?; + module.add_class::()?; + module.add_class::()?; + module.add_class::()?; + module.add_class::()?; + Ok(()) +} diff --git a/crates/qroissant-transport/Cargo.toml b/crates/qroissant-transport/Cargo.toml new file mode 100644 index 0000000..1ad6bdc --- /dev/null +++ b/crates/qroissant-transport/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "qroissant-transport" +version.workspace = true +edition.workspace = true +license.workspace = true +publish = false + +[lib] +name = "qroissant_transport" +path = "src/lib.rs" + +[dependencies] +bytes = "1.11.1" +qroissant-core = { path = "../qroissant-core" } +tokio = { workspace = true, features = ["io-util", "net", "time"] } +futures = { workspace = true } diff --git a/crates/qroissant-transport/src/asynchronous.rs b/crates/qroissant-transport/src/asynchronous.rs new file mode 100644 index 0000000..00090f2 --- /dev/null +++ b/crates/qroissant-transport/src/asynchronous.rs @@ -0,0 +1,475 @@ +use std::pin::Pin; +use std::task::Context; +use std::task::Poll; +use std::time::Duration; + +use qroissant_core::Compression; +use qroissant_core::HEADER_LEN; +use qroissant_core::MessageHeader; +use qroissant_core::StreamingDecompressor; +use qroissant_core::read_message_length; +use tokio::io::AsyncRead; +use tokio::io::AsyncReadExt; +use tokio::io::AsyncWrite; +use tokio::io::AsyncWriteExt; +use tokio::io::ReadBuf; +use tokio::net::TcpStream; +#[cfg(unix)] +use tokio::net::UnixStream; + +use crate::TransportError; +use crate::TransportResult; +use crate::synchronous::CLIENT_CAPABILITY; + +pub enum AsyncTransport { + Tcp(TcpStream), + #[cfg(unix)] + Unix(UnixStream), +} + +impl AsyncRead for AsyncTransport { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut 
Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + match &mut *self { + Self::Tcp(stream) => Pin::new(stream).poll_read(cx, buf), + #[cfg(unix)] + Self::Unix(stream) => Pin::new(stream).poll_read(cx, buf), + } + } +} + +impl AsyncWrite for AsyncTransport { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + match &mut *self { + Self::Tcp(stream) => Pin::new(stream).poll_write(cx, buf), + #[cfg(unix)] + Self::Unix(stream) => Pin::new(stream).poll_write(cx, buf), + } + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match &mut *self { + Self::Tcp(stream) => Pin::new(stream).poll_flush(cx), + #[cfg(unix)] + Self::Unix(stream) => Pin::new(stream).poll_flush(cx), + } + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match &mut *self { + Self::Tcp(stream) => Pin::new(stream).poll_shutdown(cx), + #[cfg(unix)] + Self::Unix(stream) => Pin::new(stream).poll_shutdown(cx), + } + } +} + +impl AsyncTransport { + pub async fn shutdown(&mut self) -> std::io::Result<()> { + match self { + Self::Tcp(stream) => stream.shutdown().await, + #[cfg(unix)] + Self::Unix(stream) => stream.shutdown().await, + } + } + + pub fn take_error(&self) -> std::io::Result> { + match self { + Self::Tcp(stream) => stream.take_error(), + #[cfg(unix)] + Self::Unix(stream) => stream.take_error(), + } + } +} + +pub struct AsyncPooledTransport { + transport: AsyncTransport, + broken: bool, +} + +impl AsyncPooledTransport { + pub fn new(transport: AsyncTransport) -> Self { + Self { + transport, + broken: false, + } + } + + pub fn mark_broken(&mut self) { + self.broken = true; + } + + pub fn is_broken(&self) -> bool { + self.broken || self.transport.take_error().ok().flatten().is_some() + } + + pub fn transport_mut(&mut self) -> &mut AsyncTransport { + &mut self.transport + } +} + +/// A reader that transparently decompresses q IPC payloads as they are read. 
+pub struct DecompressingReader<'a, R> { + reader: &'a mut R, + decompressor: Option, + remaining_compressed: usize, + buffer: Vec, +} + +impl<'a, R: AsyncRead + Unpin> DecompressingReader<'a, R> { + pub fn new( + reader: &'a mut R, + decompressor: Option, + remaining_compressed: usize, + ) -> Self { + Self { + reader, + decompressor, + remaining_compressed, + buffer: vec![0_u8; 8192], + } + } +} + +impl<'a, R: AsyncRead + Unpin> AsyncRead for DecompressingReader<'a, R> { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + let this = &mut *self; + + if let Some(decompressor) = &mut this.decompressor { + // If we have decompressed data available, yield it first. + if decompressor.unread_len() > 0 { + let chunk = decompressor.next_chunk(); + let to_copy = chunk.len().min(buf.remaining()); + buf.put_slice(&chunk[..to_copy]); + decompressor.consume(to_copy); + return Poll::Ready(Ok(())); + } + + // If decompression is complete and no more unread bytes, EOF. + if decompressor.is_complete() { + return Poll::Ready(Ok(())); + } + + // Otherwise, read more compressed data from the underlying reader. + if this.remaining_compressed > 0 { + let want = this.remaining_compressed.min(this.buffer.len()); + let mut read_buf = ReadBuf::new(&mut this.buffer[..want]); + match Pin::new(&mut this.reader).poll_read(cx, &mut read_buf) { + Poll::Ready(Ok(())) => { + let read = read_buf.filled().len(); + if read == 0 && want > 0 { + return Poll::Ready(Err(std::io::Error::new( + std::io::ErrorKind::UnexpectedEof, + "unexpected EOF reading compressed body", + ))); + } + this.remaining_compressed -= read; + decompressor.feed(read_buf.filled()).map_err(|e| { + std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string()) + })?; + + // Recursive call to yield the newly decompressed bytes. 
+ return self.poll_read(cx, buf); + } + Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), + Poll::Pending => return Poll::Pending, + } + } + Poll::Ready(Ok(())) + } else { + // Uncompressed path: direct read from underlying reader. + Pin::new(&mut this.reader).poll_read(cx, buf) + } + } +} + +impl AsyncRead for AsyncPooledTransport { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + Pin::new(&mut self.transport).poll_read(cx, buf) + } +} + +impl AsyncWrite for AsyncPooledTransport { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut self.transport).poll_write(cx, buf) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.transport).poll_flush(cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.transport).poll_shutdown(cx) + } +} + +fn credentials_bytes(username: Option<&str>, password: Option<&str>) -> Vec { + let username = username.unwrap_or_default(); + let password = password.unwrap_or_default(); + let mut bytes = format!("{username}:{password}").into_bytes(); + bytes.push(CLIENT_CAPABILITY); + bytes.push(0); + bytes +} + +fn timeout_error(context: &str, timeout_ms: u64) -> TransportError { + TransportError::Io(std::io::Error::new( + std::io::ErrorKind::TimedOut, + format!("{context} timed out after {timeout_ms}ms"), + )) +} + +async fn run_with_timeout( + timeout_ms: Option, + context: &str, + future: F, +) -> TransportResult +where + F: std::future::Future>, +{ + match timeout_ms { + Some(timeout_ms) => tokio::time::timeout(Duration::from_millis(timeout_ms), future) + .await + .map_err(|_| timeout_error(context, timeout_ms))? 
+ .map_err(TransportError::Io), + None => future.await.map_err(TransportError::Io), + } +} + +async fn perform_handshake( + stream: &mut S, + username: Option<&str>, + password: Option<&str>, + timeout_ms: Option, +) -> TransportResult +where + S: AsyncRead + AsyncWrite + Unpin, +{ + run_with_timeout( + timeout_ms, + "q IPC handshake write", + stream.write_all(&credentials_bytes(username, password)), + ) + .await?; + run_with_timeout(timeout_ms, "q IPC handshake flush", stream.flush()).await?; + + let mut capability = [0_u8; 1]; + run_with_timeout( + timeout_ms, + "q IPC handshake read", + stream.read_exact(&mut capability), + ) + .await?; + Ok(capability[0]) +} + +pub async fn connect_tcp_transport( + host: &str, + port: u16, + username: Option<&str>, + password: Option<&str>, + timeout_ms: Option, +) -> TransportResult { + let mut stream = + run_with_timeout(timeout_ms, "TCP connect", TcpStream::connect((host, port))).await?; + stream.set_nodelay(true)?; + perform_handshake(&mut stream, username, password, timeout_ms).await?; + Ok(AsyncTransport::Tcp(stream)) +} + +#[cfg(unix)] +pub async fn connect_unix_transport( + path: &str, + username: Option<&str>, + password: Option<&str>, + timeout_ms: Option, +) -> TransportResult { + let mut stream = + run_with_timeout(timeout_ms, "Unix socket connect", UnixStream::connect(path)).await?; + perform_handshake(&mut stream, username, password, timeout_ms).await?; + Ok(AsyncTransport::Unix(stream)) +} + +pub async fn read_frame(stream: &mut S) -> TransportResult> +where + S: AsyncRead + Unpin, +{ + let mut header = [0_u8; HEADER_LEN]; + stream.read_exact(&mut header).await?; + let message_length = read_message_length(&header) + .map_err(|error| TransportError::Protocol(error.to_string()))?; + let mut frame = vec![0_u8; message_length]; + frame[..HEADER_LEN].copy_from_slice(&header); + stream.read_exact(&mut frame[HEADER_LEN..]).await?; + Ok(frame) +} + +pub async fn request_frame_over(stream: &mut S, payload: &[u8]) -> 
TransportResult> +where + S: AsyncRead + AsyncWrite + Unpin, +{ + stream.write_all(payload).await?; + stream.flush().await?; + read_frame(stream).await +} + +pub async fn begin_streaming_frame_over( + stream: &mut S, + payload: &[u8], +) -> TransportResult<([u8; HEADER_LEN], usize)> +where + S: AsyncRead + AsyncWrite + Unpin, +{ + stream.write_all(payload).await?; + stream.flush().await?; + + let mut header = [0_u8; HEADER_LEN]; + stream.read_exact(&mut header).await?; + let message_length = read_message_length(&header) + .map_err(|error| TransportError::Protocol(error.to_string()))?; + Ok((header, message_length - HEADER_LEN)) +} + +/// Async variant of [`crate::synchronous::request_frame_streaming_over`]. +/// +/// Sends a payload and reads the response frame, using streaming decompression +/// when the response is compressed. +pub async fn request_frame_streaming_over( + stream: &mut S, + payload: &[u8], +) -> TransportResult> +where + S: AsyncRead + AsyncWrite + Unpin, +{ + stream.write_all(payload).await?; + stream.flush().await?; + + // Read the 8-byte header. + let mut header_bytes = [0_u8; HEADER_LEN]; + stream.read_exact(&mut header_bytes).await?; + + let header = MessageHeader::from_bytes(header_bytes) + .map_err(|error| TransportError::Protocol(error.to_string()))?; + let body_len = header.body_len(); + + if header.compression() == Compression::Uncompressed { + let mut frame = vec![0_u8; header.size()]; + frame[..HEADER_LEN].copy_from_slice(&header_bytes); + stream.read_exact(&mut frame[HEADER_LEN..]).await?; + return Ok(frame); + } + + // Compressed frame: read the 4-byte size prefix first. 
+ if body_len < 4 { + return Err(TransportError::Protocol( + "compressed body must be at least 4 bytes for size prefix".to_string(), + )); + } + + let mut size_prefix = [0_u8; 4]; + stream.read_exact(&mut size_prefix).await?; + + let mut decompressor = StreamingDecompressor::new(size_prefix, header.encoding()) + .map_err(|error| TransportError::Protocol(error.to_string()))?; + + // Read the remaining compressed body in chunks. + let remaining = body_len - 4; + let mut total_read = 0_usize; + let mut chunk = vec![0_u8; 8192]; + + while total_read < remaining { + let want = (remaining - total_read).min(chunk.len()); + stream.read_exact(&mut chunk[..want]).await?; + decompressor + .feed(&chunk[..want]) + .map_err(|error| TransportError::Protocol(error.to_string()))?; + total_read += want; + } + + if !decompressor.is_complete() { + return Err(TransportError::Protocol( + "streaming decompression did not complete after reading entire body".to_string(), + )); + } + + let decompressed = decompressor + .finish() + .map_err(|error| TransportError::Protocol(error.to_string()))?; + + // Reconstruct as an uncompressed frame. 
+ let new_size = HEADER_LEN + decompressed.len(); + let new_header = qroissant_core::MessageHeader::new( + header.encoding(), + header.message_type(), + Compression::Uncompressed, + new_size, + ) + .map_err(|error| TransportError::Protocol(error.to_string()))?; + + let mut frame = Vec::with_capacity(new_size); + frame.extend_from_slice( + &new_header + .to_bytes() + .map_err(|error| TransportError::Protocol(error.to_string()))?, + ); + frame.extend_from_slice(&decompressed); + Ok(frame) +} +use qroissant_core::pipelined::PipelinedReader; +use qroissant_core::pipelined::decode_value_async; +use qroissant_core::value::Value; + +pub async fn request_value_pipelined_over( + conn: &mut R, + payload: &[u8], +) -> TransportResult { + conn.write_all(payload).await.map_err(TransportError::Io)?; + conn.flush().await.map_err(TransportError::Io)?; + + let mut header_bytes = [0_u8; HEADER_LEN]; + conn.read_exact(&mut header_bytes) + .await + .map_err(TransportError::Io)?; + let header = + MessageHeader::parse(&header_bytes).map_err(|e| TransportError::Protocol(e.to_string()))?; + + let (decompressor, remaining_compressed) = if header.compression() != Compression::Uncompressed + { + let mut size_prefix = [0_u8; 4]; + conn.read_exact(&mut size_prefix) + .await + .map_err(TransportError::Io)?; + let decompressor = StreamingDecompressor::new(size_prefix, header.encoding()) + .map_err(|e| TransportError::Protocol(e.to_string()))?; + (Some(decompressor), header.body_len() - 4) + } else { + (None, header.body_len()) + }; + + let mut decomp_reader = DecompressingReader::new(conn, decompressor, remaining_compressed); + let mut pipelined_reader = PipelinedReader::new(&mut decomp_reader, header.encoding()) + .map_err(|e| TransportError::Protocol(e.to_string()))?; + + decode_value_async(&mut pipelined_reader) + .await + .map_err(|e| TransportError::Protocol(e.to_string())) +} diff --git a/crates/qroissant-transport/src/error.rs b/crates/qroissant-transport/src/error.rs new file mode 100644 
index 0000000..a5ac0f4 --- /dev/null +++ b/crates/qroissant-transport/src/error.rs @@ -0,0 +1,42 @@ +use std::fmt; + +pub type TransportResult = Result; + +#[derive(Debug)] +pub enum TransportError { + Io(std::io::Error), + InvalidEndpoint(String), + InvalidQueryLength(usize), + Protocol(String), + Closed, +} + +impl fmt::Display for TransportError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Io(error) => error.fmt(f), + Self::InvalidEndpoint(message) => write!(f, "{message}"), + Self::InvalidQueryLength(length) => write!( + f, + "q query string length {length} exceeds 32-bit q IPC capacity" + ), + Self::Protocol(message) => write!(f, "{message}"), + Self::Closed => write!(f, "connection is closed"), + } + } +} + +impl std::error::Error for TransportError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + Self::Io(error) => Some(error), + _ => None, + } + } +} + +impl From for TransportError { + fn from(value: std::io::Error) -> Self { + Self::Io(value) + } +} diff --git a/crates/qroissant-transport/src/lib.rs b/crates/qroissant-transport/src/lib.rs new file mode 100644 index 0000000..27e5537 --- /dev/null +++ b/crates/qroissant-transport/src/lib.rs @@ -0,0 +1,37 @@ +//! Shared q IPC transport primitives. 
+ +mod asynchronous; +mod error; +mod synchronous; + +pub use asynchronous::AsyncPooledTransport; +pub use asynchronous::AsyncTransport; +pub use asynchronous::begin_streaming_frame_over as begin_streaming_frame_over_async; +pub use asynchronous::connect_tcp_transport as connect_tcp_transport_async; +#[cfg(unix)] +pub use asynchronous::connect_unix_transport as connect_unix_transport_async; +pub use asynchronous::read_frame as read_frame_async; +pub use asynchronous::request_frame_over as request_frame_over_async; +pub use asynchronous::request_frame_streaming_over as request_frame_streaming_over_async; +pub use asynchronous::request_value_pipelined_over as request_value_pipelined_over_async; +pub use error::TransportError; +pub use error::TransportResult; +pub use qroissant_core::HEADER_LEN as QIPC_HEADER_LEN; +pub use synchronous::CLIENT_CAPABILITY; +pub use synchronous::SyncConnection; +pub use synchronous::SyncPooledTransport; +pub use synchronous::SyncTransport; +pub use synchronous::begin_streaming_frame_over; +pub use synchronous::connect_tcp_transport; +#[cfg(unix)] +pub use synchronous::connect_unix_transport; +pub use synchronous::credentials_bytes; +pub use synchronous::encode_sync_query; +pub use synchronous::extract_q_error; +pub use synchronous::parse_message_header; +pub use synchronous::perform_handshake; +pub use synchronous::request_frame_over; +pub use synchronous::request_frame_streaming_over; +pub use synchronous::validate_response_frame; +pub use synchronous::validate_response_header; +pub use synchronous::validate_response_header_bytes; diff --git a/crates/qroissant-transport/src/synchronous.rs b/crates/qroissant-transport/src/synchronous.rs new file mode 100644 index 0000000..ade35e1 --- /dev/null +++ b/crates/qroissant-transport/src/synchronous.rs @@ -0,0 +1,420 @@ +use std::io::Read; +use std::io::Write; +use std::net::Shutdown; +use std::net::TcpStream; +#[cfg(unix)] +use std::os::unix::net::UnixStream; +use std::time::Duration; + +use 
qroissant_core::Attribute; +use qroissant_core::Compression; +use qroissant_core::Encoding; +use qroissant_core::Frame; +use qroissant_core::HEADER_LEN; +use qroissant_core::MessageHeader; +use qroissant_core::MessageType; +use qroissant_core::StreamingDecompressor; +use qroissant_core::Value; +use qroissant_core::Vector; +use qroissant_core::VectorData; +use qroissant_core::encode_message; +use qroissant_core::read_frame; +use qroissant_core::read_message_length; + +use crate::TransportError; +use crate::TransportResult; + +pub const CLIENT_CAPABILITY: u8 = 3; + +pub enum SyncTransport { + Tcp(TcpStream), + #[cfg(unix)] + Unix(UnixStream), +} + +impl SyncTransport { + pub fn shutdown(&mut self) -> std::io::Result<()> { + match self { + Self::Tcp(stream) => stream.shutdown(Shutdown::Both), + #[cfg(unix)] + Self::Unix(stream) => stream.shutdown(Shutdown::Both), + } + } + + pub fn take_error(&self) -> std::io::Result> { + match self { + Self::Tcp(stream) => stream.take_error(), + #[cfg(unix)] + Self::Unix(stream) => stream.take_error(), + } + } + + pub fn set_timeouts(&self, timeout_ms: Option) -> std::io::Result<()> { + let timeout = timeout_ms.map(Duration::from_millis); + match self { + Self::Tcp(stream) => { + stream.set_read_timeout(timeout)?; + stream.set_write_timeout(timeout)?; + stream.set_nodelay(true) + } + #[cfg(unix)] + Self::Unix(stream) => { + stream.set_read_timeout(timeout)?; + stream.set_write_timeout(timeout) + } + } + } +} + +impl Read for SyncTransport { + fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + match self { + Self::Tcp(stream) => stream.read(buf), + #[cfg(unix)] + Self::Unix(stream) => stream.read(buf), + } + } +} + +impl Write for SyncTransport { + fn write(&mut self, buf: &[u8]) -> std::io::Result { + match self { + Self::Tcp(stream) => stream.write(buf), + #[cfg(unix)] + Self::Unix(stream) => stream.write(buf), + } + } + + fn flush(&mut self) -> std::io::Result<()> { + match self { + Self::Tcp(stream) => stream.flush(), + 
#[cfg(unix)] + Self::Unix(stream) => stream.flush(), + } + } +} + +pub struct SyncPooledTransport { + transport: SyncTransport, + broken: bool, +} + +impl SyncPooledTransport { + pub fn new(transport: SyncTransport) -> Self { + Self { + transport, + broken: false, + } + } + + pub fn mark_broken(&mut self) { + self.broken = true; + } + + pub fn is_broken(&self) -> bool { + self.broken || self.transport.take_error().ok().flatten().is_some() + } + + pub fn shutdown(&mut self) -> std::io::Result<()> { + self.transport.shutdown() + } +} + +impl Read for SyncPooledTransport { + fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + self.transport.read(buf) + } +} + +impl Write for SyncPooledTransport { + fn write(&mut self, buf: &[u8]) -> std::io::Result { + self.transport.write(buf) + } + + fn flush(&mut self) -> std::io::Result<()> { + self.transport.flush() + } +} + +pub fn credentials_bytes(username: Option<&str>, password: Option<&str>) -> Vec { + let username = username.unwrap_or_default(); + let password = password.unwrap_or_default(); + let mut bytes = format!("{username}:{password}").into_bytes(); + bytes.push(CLIENT_CAPABILITY); + bytes.push(0); + bytes +} + +pub fn perform_handshake( + stream: &mut S, + username: Option<&str>, + password: Option<&str>, +) -> TransportResult { + stream.write_all(&credentials_bytes(username, password))?; + stream.flush()?; + + let mut capability = [0_u8; 1]; + stream.read_exact(&mut capability)?; + Ok(capability[0]) +} + +pub fn encode_sync_query(message: &str) -> TransportResult> { + let _ = i32::try_from(message.len()) + .map_err(|_| TransportError::InvalidQueryLength(message.len()))?; + let value = Value::Vector(Vector::new( + Attribute::None, + VectorData::Char(bytes::Bytes::copy_from_slice(message.as_bytes())), + )); + + encode_message( + &value, + Encoding::LittleEndian, + MessageType::Synchronous, + Compression::Uncompressed, + ) + .map_err(|error| TransportError::Protocol(error.to_string())) +} + +pub fn 
extract_q_error(frame_bytes: &[u8]) -> TransportResult> { + let frame = + Frame::parse(frame_bytes).map_err(|error| TransportError::Protocol(error.to_string()))?; + let body = frame.body(); + if body.first().copied() != Some(128) { + return Ok(None); + } + + let message = match body[1..].iter().position(|byte| *byte == 0) { + Some(end) => &body[1..1 + end], + None => &body[1..], + }; + Ok(Some(String::from_utf8_lossy(message).into_owned())) +} + +pub fn parse_message_header(header_bytes: [u8; HEADER_LEN]) -> TransportResult { + MessageHeader::from_bytes(header_bytes) + .map_err(|error| TransportError::Protocol(error.to_string())) +} + +pub fn validate_response_header(header: MessageHeader) -> TransportResult<()> { + if header.message_type() != MessageType::Response { + return Err(TransportError::Protocol(format!( + "expected a q response frame, received {:?}", + header.message_type() + ))); + } + + Ok(()) +} + +pub fn validate_response_header_bytes( + header_bytes: [u8; HEADER_LEN], +) -> TransportResult { + let header = parse_message_header(header_bytes)?; + validate_response_header(header)?; + Ok(header) +} + +pub fn validate_response_frame(frame_bytes: &[u8]) -> TransportResult { + let frame = + Frame::parse(frame_bytes).map_err(|error| TransportError::Protocol(error.to_string()))?; + let header = frame.header(); + validate_response_header(header)?; + Ok(header) +} + +pub fn connect_tcp_transport( + host: &str, + port: u16, + username: Option<&str>, + password: Option<&str>, + timeout_ms: Option, +) -> TransportResult { + let mut stream = SyncTransport::Tcp(TcpStream::connect((host, port))?); + stream.set_timeouts(timeout_ms)?; + perform_handshake(&mut stream, username, password)?; + Ok(stream) +} + +#[cfg(unix)] +pub fn connect_unix_transport( + path: &str, + username: Option<&str>, + password: Option<&str>, + timeout_ms: Option, +) -> TransportResult { + let mut stream = SyncTransport::Unix(UnixStream::connect(path)?); + stream.set_timeouts(timeout_ms)?; + 
perform_handshake(&mut stream, username, password)?; + Ok(stream) +} + +pub fn request_frame_over( + stream: &mut S, + payload: &[u8], +) -> TransportResult> { + stream.write_all(payload)?; + stream.flush()?; + read_frame(stream).map_err(|error| TransportError::Protocol(error.to_string())) +} + +/// Sends a payload and reads the response frame, using streaming decompression +/// when the response is compressed. +/// +/// For compressed frames, the body is read in chunks and fed to a +/// [`StreamingDecompressor`] incrementally, overlapping network I/O with +/// decompression work. The returned frame is reconstructed as an +/// *uncompressed* frame so callers can decode it normally. +/// +/// For uncompressed frames, this behaves identically to [`request_frame_over`]. +pub fn request_frame_streaming_over( + stream: &mut S, + payload: &[u8], +) -> TransportResult> { + stream.write_all(payload)?; + stream.flush()?; + + // Read the 8-byte header. + let mut header_bytes = [0_u8; HEADER_LEN]; + stream.read_exact(&mut header_bytes)?; + + let header = parse_message_header(header_bytes)?; + let body_len = header.body_len(); + + if header.compression() == Compression::Uncompressed { + // Fast path: read entire uncompressed body. + let mut frame = vec![0_u8; header.size()]; + frame[..HEADER_LEN].copy_from_slice(&header_bytes); + stream.read_exact(&mut frame[HEADER_LEN..])?; + return Ok(frame); + } + + // Compressed frame: read the 4-byte size prefix first. + if body_len < 4 { + return Err(TransportError::Protocol( + "compressed body must be at least 4 bytes for size prefix".to_string(), + )); + } + + let mut size_prefix = [0_u8; 4]; + stream.read_exact(&mut size_prefix)?; + + let mut decompressor = StreamingDecompressor::new(size_prefix, header.encoding()) + .map_err(|error| TransportError::Protocol(error.to_string()))?; + + // Read the remaining compressed body in chunks. 
+ let remaining = body_len - 4; + let mut total_read = 0_usize; + let mut chunk = [0_u8; 8192]; + + while total_read < remaining { + let want = (remaining - total_read).min(chunk.len()); + stream.read_exact(&mut chunk[..want])?; + decompressor + .feed(&chunk[..want]) + .map_err(|error| TransportError::Protocol(error.to_string()))?; + total_read += want; + } + + if !decompressor.is_complete() { + return Err(TransportError::Protocol( + "streaming decompression did not complete after reading entire body".to_string(), + )); + } + + let decompressed = decompressor + .finish() + .map_err(|error| TransportError::Protocol(error.to_string()))?; + + // Reconstruct as an uncompressed frame: header + decompressed body. + let new_size = HEADER_LEN + decompressed.len(); + let new_header = MessageHeader::new( + header.encoding(), + header.message_type(), + Compression::Uncompressed, + new_size, + ) + .map_err(|error| TransportError::Protocol(error.to_string()))?; + + let mut frame = Vec::with_capacity(new_size); + frame.extend_from_slice( + &new_header + .to_bytes() + .map_err(|error| TransportError::Protocol(error.to_string()))?, + ); + frame.extend_from_slice(&decompressed); + Ok(frame) +} + +pub fn begin_streaming_frame_over( + stream: &mut S, + payload: &[u8], +) -> TransportResult<([u8; HEADER_LEN], usize)> { + stream.write_all(payload)?; + stream.flush()?; + + let mut header = [0_u8; HEADER_LEN]; + stream.read_exact(&mut header)?; + let message_length = read_message_length(&header) + .map_err(|error| TransportError::Protocol(error.to_string()))?; + Ok((header, message_length - HEADER_LEN)) +} + +pub struct SyncConnection { + transport: Option, +} + +impl SyncConnection { + pub fn connect_tcp( + host: &str, + port: u16, + username: Option<&str>, + password: Option<&str>, + timeout_ms: Option, + ) -> TransportResult { + Ok(Self { + transport: Some(connect_tcp_transport( + host, port, username, password, timeout_ms, + )?), + }) + } + + #[cfg(unix)] + pub fn connect_unix( + 
path: &str, + username: Option<&str>, + password: Option<&str>, + timeout_ms: Option, + ) -> TransportResult { + Ok(Self { + transport: Some(connect_unix_transport( + path, username, password, timeout_ms, + )?), + }) + } + + pub fn query_frame(&mut self, message: &str) -> TransportResult> { + let payload = encode_sync_query(message)?; + let transport = self.transport.as_mut().ok_or(TransportError::Closed)?; + let frame = request_frame_over(transport, &payload)?; + validate_response_frame(&frame)?; + Ok(frame) + } + + pub fn is_closed(&self) -> bool { + self.transport.is_none() + } + + pub fn close(&mut self) -> TransportResult<()> { + let Some(mut transport) = self.transport.take() else { + return Ok(()); + }; + transport.shutdown()?; + Ok(()) + } +} + +impl Drop for SyncConnection { + fn drop(&mut self) { + let _ = self.close(); + } +} diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..dda447b --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,59 @@ +[build-system] +requires = ["maturin>=1.8,<2.0"] +build-backend = "maturin" + +[project] +name = "qroissant" +version = "0.3.0" +description = "q/kdb+ IPC client library with Arrow-native Python interoperability" +readme = "README.md" +requires-python = ">=3.10" +license = "Apache-2.0" +license-files = [] +authors = [{ name = "qroissant contributors" }] +keywords = ["kdb", "q", "ipc", "arrow", "pyo3"] +classifiers = [ + "Development Status :: 3 - Alpha", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Programming Language :: Rust", +] + +[dependency-groups] +dev = [ + "maturin>=1.8,<2.0", + "polars>=1.39.3", + "pyarrow>=23.0.1", + "pytest>=8.3,<9.0", + "ruff>=0.11,<0.12", +] +docs = [ + "mkdocs>=1.6,<2.0", + 
"mkdocs-material>=9.6,<10.0", + "mkdocs-material-extensions>=1.3,<2.0", + "mkdocstrings[python]>=0.28,<1.0", + "mkdocs-autorefs>=1.3,<2.0", +] + +[tool.maturin] +manifest-path = "crates/qroissant-python/Cargo.toml" +python-source = "python" +module-name = "qroissant._native" +features = [] + +[tool.pytest.ini_options] +testpaths = ["tests"] +pythonpath = ["python"] + +[tool.ruff] +target-version = "py310" +line-length = 88 + +[tool.ruff.lint] +select = ["E", "F", "I", "UP", "B"] + diff --git a/python/qroissant/__init__.py b/python/qroissant/__init__.py new file mode 100644 index 0000000..6e1487e --- /dev/null +++ b/python/qroissant/__init__.py @@ -0,0 +1,68 @@ +"""Public Python API for qroissant.""" + +import importlib + +__all__ = [ + "AsyncConnection", + "AsyncPool", + "AsyncRawResponse", + "Atom", + "Attribute", + "Compression", + "Connection", + "DecodeError", + "DecodeOptions", + "DecodeOptionsBuilder", + "Dictionary", + "EncodeField", + "EncodeFieldBuilder", + "EncodeOptions", + "EncodeOptionsBuilder", + "Encoding", + "Endpoint", + "List", + "ListInterpretation", + "MessageHeader", + "MessageType", + "OperationError", + "Pool", + "PoolClosedError", + "PoolError", + "PoolMetrics", + "PoolOptions", + "Primitive", + "ProtocolError", + "QRuntimeError", + "QroissantError", + "RawResponse", + "Shape", + "StringInterpretation", + "SymbolInterpretation", + "Table", + "TransportError", + "Type", + "UnionMode", + "Value", + "Vector", + "__native_available__", + "FormattingOptions", + "FormattingOptionsBuilder", + "RowDisplay", + "decode", + "encode", + "get_formatting_options", + "reset_formatting_options", + "set_formatting_options", +] + +try: + _native = importlib.import_module("qroissant._native") + for _name in __all__: + if _name != "__native_available__": + globals()[_name] = getattr(_native, _name) + __native_available__ = True +except ImportError: # pragma: no cover + __native_available__ = False + +if not __native_available__: + __all__ = 
["__native_available__"] diff --git a/python/qroissant/__init__.pyi b/python/qroissant/__init__.pyi new file mode 100644 index 0000000..68b6fcb --- /dev/null +++ b/python/qroissant/__init__.pyi @@ -0,0 +1,50 @@ +from qroissant._client import AsyncConnection as AsyncConnection +from qroissant._client import AsyncPool as AsyncPool +from qroissant._client import AsyncRawResponse as AsyncRawResponse +from qroissant._client import Connection as Connection +from qroissant._client import Pool as Pool +from qroissant._client import PoolMetrics as PoolMetrics +from qroissant._client import RawResponse as RawResponse +from qroissant._config import DecodeOptions as DecodeOptions +from qroissant._config import DecodeOptionsBuilder as DecodeOptionsBuilder +from qroissant._config import EncodeField as EncodeField +from qroissant._config import EncodeFieldBuilder as EncodeFieldBuilder +from qroissant._config import EncodeOptions as EncodeOptions +from qroissant._config import EncodeOptionsBuilder as EncodeOptionsBuilder +from qroissant._config import Endpoint as Endpoint +from qroissant._config import ListInterpretation as ListInterpretation +from qroissant._config import PoolOptions as PoolOptions +from qroissant._config import StringInterpretation as StringInterpretation +from qroissant._config import SymbolInterpretation as SymbolInterpretation +from qroissant._config import UnionMode as UnionMode +from qroissant._errors import DecodeError as DecodeError +from qroissant._errors import OperationError as OperationError +from qroissant._errors import PoolClosedError as PoolClosedError +from qroissant._errors import PoolError as PoolError +from qroissant._errors import ProtocolError as ProtocolError +from qroissant._errors import QroissantError as QroissantError +from qroissant._errors import QRuntimeError as QRuntimeError +from qroissant._errors import TransportError as TransportError +from qroissant._message import Compression as Compression +from qroissant._message import 
Encoding as Encoding +from qroissant._message import MessageHeader as MessageHeader +from qroissant._message import MessageType as MessageType +from qroissant._repr import FormattingOptions as FormattingOptions +from qroissant._repr import FormattingOptionsBuilder as FormattingOptionsBuilder +from qroissant._repr import RowDisplay as RowDisplay +from qroissant._repr import get_formatting_options as get_formatting_options +from qroissant._repr import reset_formatting_options as reset_formatting_options +from qroissant._repr import set_formatting_options as set_formatting_options +from qroissant._serde import decode as decode +from qroissant._serde import encode as encode +from qroissant._values import Atom as Atom +from qroissant._values import Attribute as Attribute +from qroissant._values import Dictionary as Dictionary +from qroissant._values import List as List +from qroissant._values import Primitive as Primitive +from qroissant._values import Shape as Shape +from qroissant._values import Table as Table +from qroissant._values import Type as Type +from qroissant._values import Value as Value +from qroissant._values import Vector as Vector + diff --git a/python/qroissant/_client.pyi b/python/qroissant/_client.pyi new file mode 100644 index 0000000..6915c38 --- /dev/null +++ b/python/qroissant/_client.pyi @@ -0,0 +1,453 @@ +from __future__ import annotations + +from types import TracebackType +from typing import Literal, overload + +from qroissant._config import DecodeOptions, Endpoint, PoolOptions +from qroissant._message import MessageHeader +from qroissant._values import Value + +class RawResponse: + """Read-only file-like wrapper over raw q IPC response bytes. + + Notes + ----- + Raw queries stream responses forward-only, decoding them on demand. 
+ + Thread Safety + ------------- + **Not thread-safe.** Each ``RawResponse`` holds an exclusive lease on its + parent connection — the connection cannot be reused until the response is + fully consumed or explicitly closed. + """ + + def __enter__(self) -> RawResponse: ... + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: ... + @property + def closed(self) -> bool: + """Check if the response is closed.""" + ... + @property + def header(self) -> MessageHeader: + """Retrieve the IPC message header.""" + ... + def close(self) -> None: + """Close the underlying response stream.""" + ... + def readable(self) -> bool: + """Return True if the stream is readable.""" + ... + def seekable(self) -> bool: + """Return True if the stream supports random access (returns False).""" + ... + def tell(self) -> int: + """Return the current stream position.""" + ... + def read(self, size: int | None = None, /) -> bytes: + """Read bytes from the response payload.""" + ... + def read1(self, size: int | None = None, /) -> bytes: + """Read bytes from the response payload with minimal blocking.""" + ... + def readinto(self, buffer: bytearray | memoryview, /) -> int: + """Read payload bytes into a pre-allocated buffer.""" + ... + def readinto1(self, buffer: bytearray | memoryview, /) -> int: + """Read payload bytes into a pre-allocated buffer with minimal blocking.""" + ... + def decode(self, *, options: DecodeOptions | None = None) -> Value: + """Decode the remaining payload bytes into a qroissant value.""" + ... + + +class AsyncRawResponse: + """Asynchronous read-only wrapper over raw q IPC response bytes.""" + + async def __aenter__(self) -> AsyncRawResponse: ... + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: ... 
+ @property + def closed(self) -> bool: + """Check if the response is closed.""" + ... + @property + def header(self) -> MessageHeader: + """Retrieve the IPC message header.""" + ... + async def close(self) -> None: + """Close the underlying response stream.""" + ... + async def read(self, size: int | None = None, /) -> bytes: + """Read bytes from the response payload asynchronously.""" + ... + async def read1(self, size: int | None = None, /) -> bytes: + """Read bytes with minimal blocking asynchronously.""" + ... + async def readinto(self, buffer: bytearray | memoryview, /) -> int: + """Read payload bytes into a buffer asynchronously.""" + ... + async def readinto1(self, buffer: bytearray | memoryview, /) -> int: + """Read payload bytes into a buffer with minimal blocking asynchronously.""" + ... + async def decode(self, *, options: DecodeOptions | None = None) -> Value: + """Decode the remaining payload bytes into a qroissant value asynchronously.""" + ... + + +class PoolMetrics: + """Snapshot of a pool's occupancy, configuration, and lifecycle state.""" + + @property + def connections(self) -> int: + """Total number of currently tracked connections.""" + ... + @property + def idle_connections(self) -> int: + """Number of tracked connections that are currently idle.""" + ... + @property + def max_size(self) -> int: + """Configured maximum pool size.""" + ... + @property + def min_idle(self) -> int | None: + """Configured minimum idle target.""" + ... + @property + def initialized(self) -> bool: + """Whether the underlying pool has been created yet.""" + ... + @property + def closed(self) -> bool: + """Whether the pool has been explicitly closed.""" + ... + + +class Connection: + """Synchronous connection capable of executing q IPC requests. + + Parameters + ---------- + endpoint : Endpoint + Addressing and transport configuration. + options : DecodeOptions | None, optional + Default options to apply to decoded queried wrapper objects. 
+ + Thread Safety + ------------- + **Not thread-safe.** Use one connection per thread, or use :class:`Pool` + for safe multi-threaded access. + """ + + def __init__( + self, + endpoint: Endpoint, + *, + options: DecodeOptions | None = None, + ) -> None: ... + def __enter__(self) -> Connection: ... + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: ... + def close(self) -> None: + """Close the underlying connection.""" + ... + @overload + def query( + self, + expr: str, + /, + *, + raw: Literal[False] = False, + decode: DecodeOptions | None = None, + ) -> Value: ... + @overload + def query( + self, + expr: str, + /, + *, + raw: Literal[True], + ) -> RawResponse: ... + def query( + self, + expr: str, + /, + *, + raw: bool = False, + decode: DecodeOptions | None = None, + ) -> Value | RawResponse: + """Execute a synchronous q expression. + + Parameters + ---------- + expr : str + The q expression to evaluate remotely. + raw : bool, default=False + If True, returns a `RawResponse` stream instead of decoding. + decode : DecodeOptions | None, optional + Decoding options for this specific query, overriding connection defaults. + + Returns + ------- + Value | RawResponse + Decoded wrapper value or a file-like raw response if `raw=True`. + """ + ... + + +class AsyncConnection: + """Asynchronous connection capable of executing q IPC requests. + + Parameters + ---------- + endpoint : Endpoint + Addressing and transport configuration. + options : DecodeOptions | None, optional + Default options to apply to decoded queried wrapper objects. + + Thread Safety + ------------- + **Not thread-safe.** Must be driven from a single async task. Use + :class:`AsyncPool` for concurrent access. + """ + + def __init__( + self, + endpoint: Endpoint, + *, + options: DecodeOptions | None = None, + ) -> None: ... + async def __aenter__(self) -> AsyncConnection: ... 
+ async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: ... + async def close(self) -> None: + """Close the underlying connection.""" + ... + @overload + async def query( + self, + expr: str, + /, + *, + raw: Literal[False] = False, + decode: DecodeOptions | None = None, + ) -> Value: ... + @overload + async def query( + self, + expr: str, + /, + *, + raw: Literal[True], + ) -> AsyncRawResponse: ... + async def query( + self, + expr: str, + /, + *, + raw: bool = False, + decode: DecodeOptions | None = None, + ) -> Value | AsyncRawResponse: + """Execute an asynchronous q expression. + + Parameters + ---------- + expr : str + The q expression to evaluate remotely. + raw : bool, default=False + If True, returns a `AsyncRawResponse` stream instead of decoding. + decode : DecodeOptions | None, optional + Decoding options for this specific query, overriding connection defaults. + + Returns + ------- + Value | AsyncRawResponse + Decoded wrapper value or a file-like raw response if `raw=True`. + """ + ... + + +class Pool: + """Synchronous connection pool for issuing q IPC requests. + + Parameters + ---------- + endpoint : Endpoint + Addressing and transport configuration. + options : DecodeOptions | None, optional + Default options to apply to decoded queried wrapper objects. + pool : PoolOptions | None, optional + Pool lifecycle and retry configuration. + + Thread Safety + ------------- + **Thread-safe.** Designed for multi-threaded use — each call checks out a + connection, uses it, and returns it to the pool automatically. + """ + + def __init__( + self, + endpoint: Endpoint, + *, + options: DecodeOptions | None = None, + pool: PoolOptions | None = None, + ) -> None: ... + def __enter__(self) -> Pool: ... + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: ... 
+ def close(self) -> None: + """Close the pool and reject future checkouts.""" + ... + def prewarm(self) -> PoolMetrics: + """Create and validate idle connections ahead of the next query.""" + ... + def metrics(self) -> PoolMetrics: + """Return a snapshot of occupancy, configuration, and lifecycle state.""" + ... + @overload + def query( + self, + expr: str, + /, + *, + raw: Literal[False] = False, + decode: DecodeOptions | None = None, + ) -> Value: ... + @overload + def query( + self, + expr: str, + /, + *, + raw: Literal[True], + ) -> RawResponse: ... + def query( + self, + expr: str, + /, + *, + raw: bool = False, + decode: DecodeOptions | None = None, + ) -> Value | RawResponse: + """Execute a pooled synchronous q expression. + + Parameters + ---------- + expr : str + The q expression to evaluate remotely. + raw : bool, default=False + If True, returns a `RawResponse` stream instead of decoding. + decode : DecodeOptions | None, optional + Decoding options for this specific query, overriding connection defaults. + + Returns + ------- + Value | RawResponse + Decoded wrapper value or a file-like raw response if `raw=True`. + """ + ... + + +class AsyncPool: + """Asynchronous connection pool for issuing q IPC requests. + + Parameters + ---------- + endpoint : Endpoint + Addressing and transport configuration. + options : DecodeOptions | None, optional + Default options to apply to decoded queried wrapper objects. + pool : PoolOptions | None, optional + Pool lifecycle and retry configuration. + + Thread Safety + ------------- + **Thread-safe.** Designed for concurrent use across multiple async tasks. + """ + + def __init__( + self, + endpoint: Endpoint, + *, + options: DecodeOptions | None = None, + pool: PoolOptions | None = None, + ) -> None: ... + async def __aenter__(self) -> AsyncPool: ... + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: ... 
+ async def close(self) -> None: + """Close the pool and reject future checkouts.""" + ... + async def prewarm(self) -> PoolMetrics: + """Create and validate idle connections ahead of the next query.""" + ... + async def metrics(self) -> PoolMetrics: + """Return a snapshot of occupancy, configuration, and lifecycle state.""" + ... + @overload + async def query( + self, + expr: str, + /, + *, + raw: Literal[False] = False, + decode: DecodeOptions | None = None, + ) -> Value: ... + @overload + async def query( + self, + expr: str, + /, + *, + raw: Literal[True], + ) -> AsyncRawResponse: ... + async def query( + self, + expr: str, + /, + *, + raw: bool = False, + decode: DecodeOptions | None = None, + ) -> Value | AsyncRawResponse: + """Execute a pooled asynchronous q expression. + + Parameters + ---------- + expr : str + The q expression to evaluate remotely. + raw : bool, default=False + If True, returns a `AsyncRawResponse` stream instead of decoding. + decode : DecodeOptions | None, optional + Decoding options for this specific query, overriding connection defaults. + + Returns + ------- + Value | AsyncRawResponse + Decoded wrapper value or a file-like raw response if `raw=True`. + """ + ... diff --git a/python/qroissant/_config.pyi b/python/qroissant/_config.pyi new file mode 100644 index 0000000..304f569 --- /dev/null +++ b/python/qroissant/_config.pyi @@ -0,0 +1,416 @@ +from __future__ import annotations + +import enum + +from qroissant._values import Attribute, Primitive, Shape + +class Endpoint: + """Connection destination configuration for a q process. + + Notes + ----- + An endpoint describes how to connect to a q process, including the transport + protocol (TCP or Unix socket), address, and authentication credentials. + """ + + @staticmethod + def tcp( + host: str, + port: int, + *, + username: str | None = None, + password: str | None = None, + timeout_ms: int | None = None, + ) -> Endpoint: + """Create a TCP endpoint configuration. 
+ + Parameters + ---------- + host : str + Hostname or IP address of the q process. + port : int + TCP port exposed by the q process. + username : str | None, optional + Username used during the q IPC handshake. + password : str | None, optional + Password used during the q IPC handshake. + timeout_ms : int | None, optional + Connection timeout in milliseconds. + + Returns + ------- + Endpoint + TCP endpoint configuration. + """ + ... + + @staticmethod + def unix( + path: str, + *, + username: str | None = None, + password: str | None = None, + timeout_ms: int | None = None, + ) -> Endpoint: + """Create a Unix-domain socket endpoint configuration. + + Parameters + ---------- + path : str + Filesystem path to the q Unix domain socket. + username : str | None, optional + Username used during the q IPC handshake. + password : str | None, optional + Password used during the q IPC handshake. + timeout_ms : int | None, optional + Connection timeout in milliseconds. + + Returns + ------- + Endpoint + Unix domain socket endpoint configuration. + """ + ... + + @property + def scheme(self) -> str: + """The transport scheme ('tcp' or 'unix').""" + ... + @property + def host(self) -> str | None: + """The hostname, if applicable.""" + ... + @property + def port(self) -> int | None: + """The port, if applicable.""" + ... + @property + def path(self) -> str | None: + """The socket path, if applicable.""" + ... + @property + def username(self) -> str | None: + """The configured username.""" + ... + @property + def password(self) -> str | None: + """The configured password.""" + ... + @property + def timeout_ms(self) -> int | None: + """The connection timeout in milliseconds.""" + ... + + +class PoolOptions: + """Connection pool configuration shared by sync and async pool types. + + Parameters + ---------- + max_size : int, default=10 + Maximum number of connections managed by the pool. 
+ min_idle : int | None, optional + Minimum number of idle connections to retain when the pool is warmed. + checkout_timeout_ms : int, default=30000 + Maximum time spent waiting for a pooled connection checkout. + idle_timeout_ms : int | None, optional + Maximum idle lifetime for an unused pooled connection. + max_lifetime_ms : int | None, optional + Maximum total lifetime for a pooled connection. + test_on_checkout : bool, default=True + Whether connections should be validated before they are handed out. + healthcheck_query : str | None, default="::" + Optional q expression used to validate live pooled connections. + Set to ``None`` to disable active q health checks. + retry_attempts : int, default=0 + Number of retry attempts after the initial pooled query failure. + retry_backoff_ms : int, default=0 + Delay between retry attempts, expressed in milliseconds. + """ + + def __init__( + self, + *, + max_size: int = 10, + min_idle: int | None = None, + checkout_timeout_ms: int = 30_000, + idle_timeout_ms: int | None = None, + max_lifetime_ms: int | None = None, + test_on_checkout: bool = True, + healthcheck_query: str | None = "::", + retry_attempts: int = 0, + retry_backoff_ms: int = 0, + ) -> None: ... + @property + def max_size(self) -> int: + """Configured maximum number of connections managed by the pool.""" + ... + @property + def min_idle(self) -> int | None: + """Configured minimum number of idle connections to retain.""" + ... + @property + def checkout_timeout_ms(self) -> int: + """Configured maximum wait time for a connection checkout, in milliseconds.""" + ... + @property + def idle_timeout_ms(self) -> int | None: + """Configured maximum idle lifetime for an unused connection, in milliseconds.""" + ... + @property + def max_lifetime_ms(self) -> int | None: + """Configured maximum total lifetime for a connection, in milliseconds.""" + ... 
+ @property + def test_on_checkout(self) -> bool: + """Whether connections are validated before being handed out.""" + ... + @property + def healthcheck_query(self) -> str | None: + """q expression used to validate live pooled connections, or ``None`` to disable.""" + ... + @property + def retry_attempts(self) -> int: + """Number of retry attempts after the initial pooled query failure.""" + ... + @property + def retry_backoff_ms(self) -> int: + """Delay between retry attempts, in milliseconds.""" + ... + + +class SymbolInterpretation(enum.Enum): + """Arrow representation used for q symbols.""" + UTF8 = ... + LARGE_UTF8 = ... + UTF8_VIEW = ... + DICTIONARY = ... + RAW_BYTES = ... + + +class ListInterpretation(enum.Enum): + """Arrow representation used for q lists.""" + LIST = ... + LARGE_LIST = ... + LIST_VIEW = ... + + +class StringInterpretation(enum.Enum): + """Arrow representation used for q char data.""" + UTF8 = ... + BINARY = ... + + +class UnionMode(enum.Enum): + """Arrow union representation used for mixed general lists.""" + DENSE = ... + SPARSE = ... + + +class DecodeOptions: + """Deserialization and Arrow conversion options. + + Notes + ----- + Construct instances through :meth:`builder` and finish with + :meth:`DecodeOptionsBuilder.build`. + """ + + @staticmethod + def builder() -> DecodeOptionsBuilder: + """Create a builder initialized with qroissant's default options.""" + ... + @property + def list_interpretation(self) -> ListInterpretation: + """Arrow container type used when projecting q lists.""" + ... + @property + def union_mode(self) -> UnionMode: + """Arrow union encoding used for mixed general lists.""" + ... + @property + def string_interpretation(self) -> StringInterpretation: + """Arrow type used when projecting q char vectors.""" + ... + @property + def symbol_interpretation(self) -> SymbolInterpretation: + """Arrow type used when projecting q symbol vectors and atoms.""" + ... 
+ @property + def assume_symbol_utf8(self) -> bool: + """Whether q symbols are assumed to be valid UTF-8.""" + ... + @property + def parallel(self) -> bool: + """Whether table columns are decoded in parallel using multiple threads.""" + ... + @property + def preserve_original_body(self) -> bool: + """Whether the raw IPC payload bytes are retained on the decoded value.""" + ... + @property + def validate_compressed_trailing_bytes(self) -> bool: + """Whether trailing zero bytes after LZW-decompressed output are validated.""" + ... + @property + def temporal_nulls(self) -> bool: + """Whether temporal null sentinels are mapped to ``None`` in Arrow arrays.""" + ... + @property + def treat_infinity_as_null(self) -> bool: + """Whether ±∞ sentinels are mapped to ``None`` in Arrow arrays.""" + ... + + +class DecodeOptionsBuilder: + """Builder for :class:`DecodeOptions`.""" + + def with_list_interpretation( + self, value: ListInterpretation, / + ) -> DecodeOptionsBuilder: + """Set the Arrow container type for q list projection.""" + ... + def with_union_mode(self, value: UnionMode, /) -> DecodeOptionsBuilder: + """Set the Arrow union encoding for mixed general lists.""" + ... + def with_string_interpretation( + self, value: StringInterpretation, / + ) -> DecodeOptionsBuilder: + """Set the Arrow type for q char vector projection.""" + ... + def with_symbol_interpretation( + self, value: SymbolInterpretation, / + ) -> DecodeOptionsBuilder: + """Set the Arrow type for q symbol projection.""" + ... + def with_assume_symbol_utf8(self, value: bool, /) -> DecodeOptionsBuilder: + """Set whether q symbols are assumed to be valid UTF-8.""" + ... + def with_parallel(self, value: bool, /) -> DecodeOptionsBuilder: + """Set whether table columns are decoded in parallel.""" + ... + def with_preserve_original_body(self, value: bool, /) -> DecodeOptionsBuilder: + """Set whether the raw IPC payload bytes are retained on decoded values.""" + ... 
+ def with_validate_compressed_trailing_bytes( + self, value: bool, / + ) -> DecodeOptionsBuilder: + """Set whether trailing zero bytes after LZW decompression are validated.""" + ... + def with_temporal_nulls(self, value: bool, /) -> DecodeOptionsBuilder: + """Set whether temporal null sentinels are mapped to ``None`` in Arrow arrays.""" + ... + def with_treat_infinity_as_null(self, value: bool, /) -> DecodeOptionsBuilder: + """Set whether ±∞ sentinels are mapped to ``None`` in Arrow arrays.""" + ... + def build(self) -> DecodeOptions: + """Finalize the builder into an immutable :class:`DecodeOptions` instance.""" + ... + + +class EncodeField: + """q serialization hints for a single Arrow value or table column.""" + + @staticmethod + def builder() -> EncodeFieldBuilder: + """Create a builder initialized with empty serialization hints.""" + ... + @property + def primitive(self) -> Primitive | None: + """Forced q primitive type for this field, or ``None`` to infer.""" + ... + @property + def shape(self) -> Shape | None: + """Forced q structural shape for this field, or ``None`` to infer.""" + ... + @property + def attribute(self) -> Attribute | None: + """q attribute to apply to this field, or ``None`` for no attribute.""" + ... + @property + def sorted(self) -> bool | None: + """Whether this field is sorted, or ``None`` to infer from the attribute.""" + ... + + +class EncodeFieldBuilder: + """Builder for :class:`EncodeField`.""" + + def with_primitive(self, value: Primitive, /) -> EncodeFieldBuilder: + """Force the q primitive type for this field.""" + ... + def with_shape(self, value: Shape, /) -> EncodeFieldBuilder: + """Force the q structural shape for this field.""" + ... + def with_attribute(self, value: Attribute, /) -> EncodeFieldBuilder: + """Set the q attribute to apply to this field.""" + ... + def with_sorted(self, value: bool, /) -> EncodeFieldBuilder: + """Set whether this field is sorted.""" + ... 
+ def build(self) -> EncodeField: + """Finalize the builder into immutable per-field encoding hints.""" + ... + + +class EncodeOptions: + """Encoding options for Arrow-backed :func:`qroissant.encode` calls. + + Notes + ----- + qroissant resolves Arrow inputs in the following order: + + 1. qroissant metadata embedded on the Arrow field or schema + 2. per-field overrides attached through :meth:`EncodeOptionsBuilder.with_field` + 3. global defaults on :class:`EncodeOptions` + 4. generic Arrow type inference + """ + + @staticmethod + def builder() -> EncodeOptionsBuilder: + """Create a builder initialized with qroissant's default encoding policy.""" + ... + @property + def primitive(self) -> Primitive | None: + """Global fallback q primitive type, or ``None`` to infer per-field.""" + ... + @property + def shape(self) -> Shape | None: + """Global fallback q structural shape, or ``None`` to infer per-field.""" + ... + @property + def attribute(self) -> Attribute | None: + """Global fallback q attribute, or ``None`` for no attribute.""" + ... + @property + def strict(self) -> bool: + """Whether unknown Arrow types raise an error instead of being inferred.""" + ... + def field(self, name: str, /) -> EncodeField | None: + """Return per-field encoding hints for the given column name, if any.""" + ... + + +class EncodeOptionsBuilder: + """Builder for :class:`EncodeOptions`.""" + + def with_primitive(self, value: Primitive, /) -> EncodeOptionsBuilder: + """Set the global fallback q primitive type.""" + ... + def with_shape(self, value: Shape, /) -> EncodeOptionsBuilder: + """Set the global fallback q structural shape.""" + ... + def with_attribute(self, value: Attribute, /) -> EncodeOptionsBuilder: + """Set the global fallback q attribute.""" + ... + def with_strict(self, value: bool, /) -> EncodeOptionsBuilder: + """Set whether unknown Arrow types raise an error instead of being inferred.""" + ... 
+ def with_field( + self, name: str, field: EncodeField, / + ) -> EncodeOptionsBuilder: + """Attach per-field encoding hints for the given column name.""" + ... + def build(self) -> EncodeOptions: + """Finalize the builder into immutable Arrow-to-q encoding options.""" + ... diff --git a/python/qroissant/_errors.pyi b/python/qroissant/_errors.pyi new file mode 100644 index 0000000..3075ee7 --- /dev/null +++ b/python/qroissant/_errors.pyi @@ -0,0 +1,24 @@ +class QroissantError(Exception): + """Base class for qroissant-specific failures.""" + +class DecodeError(QroissantError): + """Raised when q IPC payload decoding fails.""" + +class ProtocolError(QroissantError): + """Raised when q IPC framing or message validation fails.""" + +class TransportError(QroissantError): + """Raised when transport IO or socket operations fail.""" + +class OperationError(QroissantError): + """Raised when a qroissant operation is unsupported.""" + +class QRuntimeError(QroissantError): + """Raised when the remote q process returns an error response.""" + +class PoolError(QroissantError): + """Raised when connection pool management fails.""" + +class PoolClosedError(PoolError): + """Raised when a closed connection pool is used.""" + diff --git a/python/qroissant/_message.pyi b/python/qroissant/_message.pyi new file mode 100644 index 0000000..57ef7c7 --- /dev/null +++ b/python/qroissant/_message.pyi @@ -0,0 +1,43 @@ +from __future__ import annotations + +import enum + +class Encoding(enum.Enum): + """Endianness of the q IPC message payload.""" + LITTLE_ENDIAN = ... + BIG_ENDIAN = ... + + +class Compression(enum.Enum): + """Compression mode of the q IPC message payload.""" + UNCOMPRESSED = ... + COMPRESSED = ... + COMPRESSED_LARGE = ... + + +class MessageType(enum.Enum): + """IPC message type tag.""" + ASYNCHRONOUS = ... + SYNCHRONOUS = ... + RESPONSE = ... 
+ + +class MessageHeader: + """Header information extracted from a q IPC frame.""" + @property + def encoding(self) -> Encoding: + """The endianness of the payload.""" + ... + @property + def message_type(self) -> MessageType: + """The message type tag.""" + ... + @property + def compression(self) -> Compression: + """The compression mode.""" + ... + @property + def size(self) -> int: + """The total size of the message frame in bytes.""" + ... + diff --git a/python/qroissant/_repr.pyi b/python/qroissant/_repr.pyi new file mode 100644 index 0000000..57cd28a --- /dev/null +++ b/python/qroissant/_repr.pyi @@ -0,0 +1,67 @@ +from __future__ import annotations + +import enum + +class RowDisplay(enum.Enum): + """Row selection strategy used by qroissant repr formatting.""" + + HEAD = ... + HEAD_TAIL = ... + + +class FormattingOptions: + """Formatting options for user-facing qroissant string representations. + + Notes + ----- + These options control how qroissant values render through :func:`str` and + :func:`repr`. Apply them process-wide through + :func:`qroissant.set_formatting_options`. + """ + + @staticmethod + def builder() -> FormattingOptionsBuilder: + """Create a builder initialized with qroissant's default formatting policy.""" + ... + @property + def max_rows(self) -> int: + """Maximum number of rows displayed in table repr.""" + ... + @property + def max_columns(self) -> int: + """Maximum number of columns displayed in table repr.""" + ... + @property + def row_display(self) -> RowDisplay: + """Row selection strategy used when the row limit is exceeded.""" + ... + + +class FormattingOptionsBuilder: + """Builder for :class:`FormattingOptions`.""" + + def with_max_rows(self, value: int, /) -> FormattingOptionsBuilder: + """Set the maximum number of rows displayed in table repr.""" + ... + def with_max_columns(self, value: int, /) -> FormattingOptionsBuilder: + """Set the maximum number of columns displayed in table repr.""" + ... 
+ def with_row_display(self, value: RowDisplay, /) -> FormattingOptionsBuilder: + """Set the row selection strategy used when the row limit is exceeded.""" + ... + def build(self) -> FormattingOptions: + """Finalize the builder into immutable repr formatting options.""" + ... + + +def get_formatting_options() -> FormattingOptions: + """Get the current process-wide formatting options.""" + ... + +def set_formatting_options(options: FormattingOptions, /) -> None: + """Set the process-wide formatting options.""" + ... + +def reset_formatting_options() -> None: + """Reset the formatting options to their default values.""" + ... diff --git a/python/qroissant/_serde.pyi b/python/qroissant/_serde.pyi new file mode 100644 index 0000000..fbe563a --- /dev/null +++ b/python/qroissant/_serde.pyi @@ -0,0 +1,72 @@ +from __future__ import annotations + +from qroissant._config import DecodeOptions, EncodeOptions +from qroissant._message import Compression, Encoding, MessageType +from qroissant._values import Value + +def decode( + payload: object, + /, + *, + options: DecodeOptions | None = None, +) -> Value: + """Decode an IPC payload into a typed qroissant value. + + Parameters + ---------- + payload : object + Raw q IPC payload as ``bytes``, a read-only contiguous buffer, + or an object exposing such a buffer via ``.data``. + options : DecodeOptions | None, optional + Decoding and Arrow projection options. When omitted, the default + qroissant options are used. + + Returns + ------- + Value + Decoded q object represented by qroissant wrapper types. + + Raises + ------ + DecodeError + Raised when the IPC payload cannot be decoded into a valid q value. + ProtocolError + Raised when the frame header is malformed or unsupported. + """ + ... 
+ + +def encode( + value: object, + /, + *, + options: EncodeOptions | None = None, + encoding: Encoding = Encoding.LITTLE_ENDIAN, + message_type: MessageType = MessageType.ASYNCHRONOUS, + compression: Compression = Compression.UNCOMPRESSED, +) -> bytes: + """Encode a qroissant value into q IPC bytes. + + Parameters + ---------- + value : object + A qroissant wrapper value, ``None`` for q null, any object exposing a + compatible ``serialize`` method, or an object implementing the Arrow + PyCapsule interface. + options : EncodeOptions | None, optional + Arrow-to-q encoding hints applied when ``value`` is encoded + from an Arrow array or stream. + encoding : Encoding, default=Encoding.LITTLE_ENDIAN + Endianness used in the generated IPC payload. + message_type : MessageType, default=MessageType.ASYNCHRONOUS + IPC message type tag written into the frame header. + compression : Compression, default=Compression.UNCOMPRESSED + Compression mode for payload encoding. + + Returns + ------- + bytes + Encoded q IPC payload. + """ + ... + diff --git a/python/qroissant/_values.pyi b/python/qroissant/_values.pyi new file mode 100644 index 0000000..4209727 --- /dev/null +++ b/python/qroissant/_values.pyi @@ -0,0 +1,234 @@ +from __future__ import annotations + +import enum +from collections.abc import Iterator + +from qroissant._config import EncodeOptions +from qroissant._message import Compression, Encoding, MessageType + +class Attribute(enum.Enum): + """q attribute applied to a value (e.g., sorted, unique, parted, grouped).""" + NONE = ... + SORTED = ... + UNIQUE = ... + PARTED = ... + GROUPED = ... + + +class Shape(enum.Enum): + """Structural shape of the q value.""" + ATOM = ... + VECTOR = ... + LIST = ... + DICTIONARY = ... + TABLE = ... + UNARY_PRIMITIVE = ... + + +class Primitive(enum.Enum): + """Underlying primitive domain of the q value.""" + BOOLEAN = ... + GUID = ... + BYTE = ... + SHORT = ... + INT = ... + LONG = ... + REAL = ... + FLOAT = ... + CHAR = ... 
+ SYMBOL = ... + TIMESTAMP = ... + MONTH = ... + DATE = ... + DATETIME = ... + TIMESPAN = ... + MINUTE = ... + SECOND = ... + TIME = ... + MIXED = ... + + +class Type: + """Type descriptor for a qroissant value.""" + @property + def primitive(self) -> Primitive | None: + """The underlying q primitive, or None if mixed.""" + ... + @property + def shape(self) -> Shape: + """The structural shape of the value.""" + ... + @property + def attribute(self) -> Attribute | None: + """The q attribute applied to the value, if any.""" + ... + @property + def sorted(self) -> bool | None: + """Whether the value is sorted.""" + ... + + +class Value: + """Base class for decoded qroissant wrapper objects.""" + @property + def qtype(self) -> Type: + """The full q type descriptor.""" + ... + @property + def primitive(self) -> Primitive | None: + """The underlying q primitive, or None if mixed.""" + ... + @property + def shape(self) -> Shape: + """The structural shape of the value.""" + ... + @property + def attribute(self) -> Attribute | None: + """The q attribute applied to the value, if any.""" + ... + def serialize( + self, + *, + options: EncodeOptions | None = None, + encoding: Encoding = ..., + message_type: MessageType = ..., + compression: Compression = ..., + ) -> bytes: + """Serialize the value into a q IPC frame. + + Parameters + ---------- + options : EncodeOptions | None, optional + Arrow-to-q encoding hints. When omitted, the default encoding + policy is used. + encoding : Encoding, default=Encoding.LITTLE_ENDIAN + Endianness of the generated IPC payload. + message_type : MessageType, default=MessageType.ASYNCHRONOUS + IPC message type tag written into the frame header. + compression : Compression, default=Compression.UNCOMPRESSED + Compression mode applied to the payload. + + Returns + ------- + bytes + Encoded q IPC payload. + """ + ... 
+ + +class Atom(Value): + """Scalar q value (e.g., integer, float, symbol).""" + def as_py(self) -> object: + """Convert the atom to a native Python type.""" + ... + @property + def value(self) -> object: + """The atom's native Python value (equivalent to ``as_py()``).""" + ... + def is_null(self) -> bool: + """Return ``True`` if this atom holds the q null sentinel for its type.""" + ... + def is_infinite(self) -> bool: + """``True`` if this atom is a q infinity sentinel (±∞).""" + ... + def __arrow_c_array__( + self, requested_schema: object | None = None, / + ) -> object: + """Export the atom as an Arrow array via the PyCapsule Protocol.""" + ... + def __repr__(self) -> str: ... + def __str__(self) -> str: ... + + +class Vector(Value): + """Homogeneous q list of primitive values.""" + def __len__(self) -> int: + """Return the length of the vector.""" + ... + def __iter__(self) -> Iterator[object]: + """Iterate over the elements of the vector.""" + ... + def __getitem__(self, index: int, /) -> object: + """Get an element by index.""" + ... + def to_list(self) -> list[object]: + """Convert the vector to a Python list.""" + ... + def __arrow_c_array__( + self, requested_schema: object | None = None, / + ) -> object: + """Export the vector as an Arrow array via the PyCapsule Protocol.""" + ... + def __repr__(self) -> str: ... + def __str__(self) -> str: ... + + +class List(Value): + """Heterogeneous/mixed q list of general values.""" + def __len__(self) -> int: + """Return the length of the list.""" + ... + def __iter__(self) -> Iterator[Value]: + """Iterate over the elements of the list.""" + ... + def __getitem__(self, index: int, /) -> Value: + """Get an element by index.""" + ... + def to_list(self) -> list[Value]: + """Convert the list to a Python list of values.""" + ... + def __arrow_c_array__( + self, requested_schema: object | None = None, / + ) -> object: + """Export the list as an Arrow array via the PyCapsule Protocol.""" + ... 
+ def __repr__(self) -> str: ... + def __str__(self) -> str: ... + + +class Dictionary(Value): + """q dictionary mapping keys to values.""" + @property + def keys(self) -> Value: + """The dictionary keys.""" + ... + @property + def values(self) -> Value: + """The dictionary values.""" + ... + def __len__(self) -> int: + """Return the number of key-value pairs.""" + ... + def __arrow_c_array__( + self, requested_schema: object | None = None, / + ) -> object: + """Export the dictionary as an Arrow StructArray via the PyCapsule Protocol.""" + ... + def __repr__(self) -> str: ... + def __str__(self) -> str: ... + + +class Table(Value): + """q table representing tabular data.""" + @property + def columns(self) -> list[str]: + """The column names.""" + ... + @property + def num_rows(self) -> int: + """The number of rows.""" + ... + @property + def num_columns(self) -> int: + """The number of columns.""" + ... + def column(self, name: str, /) -> Value: + """Get a column by name.""" + ... + def __arrow_c_stream__( + self, requested_schema: object | None = None, / + ) -> object: + """Export the table as an Arrow stream via the PyCapsule Protocol.""" + ... + def __repr__(self) -> str: ... + def __str__(self) -> str: ... diff --git a/python/qroissant/py.typed b/python/qroissant/py.typed new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/python/qroissant/py.typed @@ -0,0 +1 @@ +