From 095f715a8045933159cb7daf3302246b9ca50658 Mon Sep 17 00:00:00 2001 From: edef Date: Tue, 30 Apr 2024 07:15:37 +0000 Subject: refactor(nix-compat/wire): drop primitive functions These may as well be inlined, and hardly need tests, since they just alias AsyncReadExt::read_u64_le / AsyncWriteExt::write_u64_le. Boolean reading is worth making explicit, since callers may differ on how they want to handle values other than 0 and 1. Boolean writing simplifies to `.write_u64_le(x as u64)`, which is also fine to inline. Change-Id: Ief9722fe886688693feb924ff0306b5bc68dd7a2 Reviewed-on: https://cl.tvl.fyi/c/depot/+/11549 Reviewed-by: flokli Tested-by: BuildkiteCI --- tvix/nix-compat/src/nix_daemon/worker_protocol.rs | 46 +++++++------- tvix/nix-compat/src/wire/bytes/mod.rs | 8 +-- tvix/nix-compat/src/wire/bytes/reader/mod.rs | 6 +- tvix/nix-compat/src/wire/mod.rs | 3 - tvix/nix-compat/src/wire/primitive.rs | 74 ----------------------- users/picnoir/tvix-daemon/src/main.rs | 8 ++- 6 files changed, 33 insertions(+), 112 deletions(-) delete mode 100644 tvix/nix-compat/src/wire/primitive.rs diff --git a/tvix/nix-compat/src/nix_daemon/worker_protocol.rs b/tvix/nix-compat/src/nix_daemon/worker_protocol.rs index 58a48d1bd..9ffceffce 100644 --- a/tvix/nix-compat/src/nix_daemon/worker_protocol.rs +++ b/tvix/nix-compat/src/nix_daemon/worker_protocol.rs @@ -131,27 +131,27 @@ pub async fn read_client_settings( r: &mut R, client_version: ProtocolVersion, ) -> std::io::Result { - let keep_failed = wire::read_bool(r).await?; - let keep_going = wire::read_bool(r).await?; - let try_fallback = wire::read_bool(r).await?; - let verbosity_uint = wire::read_u64(r).await?; + let keep_failed = r.read_u64_le().await? != 0; + let keep_going = r.read_u64_le().await? != 0; + let try_fallback = r.read_u64_le().await? != 0; + let verbosity_uint = r.read_u64_le().await?; let verbosity = Verbosity::from_u64(verbosity_uint).ok_or_else(|| { Error::new( ErrorKind::InvalidData, format!("Can't convert integer {} to verbosity", verbosity_uint), ) })?; - let max_build_jobs = wire::read_u64(r).await?; - let max_silent_time = wire::read_u64(r).await?; - _ = wire::read_u64(r).await?; // obsolete useBuildHook - let verbose_build = wire::read_bool(r).await?; - _ = wire::read_u64(r).await?; // obsolete logType - _ = wire::read_u64(r).await?; // obsolete printBuildTrace - let build_cores = wire::read_u64(r).await?; - let use_substitutes = wire::read_bool(r).await?; + let max_build_jobs = r.read_u64_le().await?; + let max_silent_time = r.read_u64_le().await?; + _ = r.read_u64_le().await?; // obsolete useBuildHook + let verbose_build = r.read_u64_le().await? != 0; + _ = r.read_u64_le().await?; // obsolete logType + _ = r.read_u64_le().await?; // obsolete printBuildTrace + let build_cores = r.read_u64_le().await?; + let use_substitutes = r.read_u64_le().await? != 0; let mut overrides = HashMap::new(); if client_version.minor() >= 12 { - let num_overrides = wire::read_u64(r).await?; + let num_overrides = r.read_u64_le().await?; for _ in 0..num_overrides { let name = wire::read_string(r, 0..MAX_SETTING_SIZE).await?; let value = wire::read_string(r, 0..MAX_SETTING_SIZE).await?; @@ -197,17 +197,17 @@ pub async fn server_handshake_client<'a, RW: 'a>( where &'a mut RW: AsyncReadExt + AsyncWriteExt + Unpin, { - let worker_magic_1 = wire::read_u64(&mut conn).await?; + let worker_magic_1 = conn.read_u64_le().await?; if worker_magic_1 != WORKER_MAGIC_1 { Err(std::io::Error::new( ErrorKind::InvalidData, format!("Incorrect worker magic number received: {}", worker_magic_1), )) } else { - wire::write_u64(&mut conn, WORKER_MAGIC_2).await?; - wire::write_u64(&mut conn, PROTOCOL_VERSION.into()).await?; + conn.write_u64_le(WORKER_MAGIC_2).await?; + conn.write_u64_le(PROTOCOL_VERSION.into()).await?; conn.flush().await?; - let client_version = wire::read_u64(&mut conn).await?; + let client_version = conn.read_u64_le().await?; // Parse into ProtocolVersion. let client_version: ProtocolVersion = client_version .try_into() @@ -220,14 +220,14 @@ where } if client_version.minor() >= 14 { // Obsolete CPU affinity. - let read_affinity = wire::read_u64(&mut conn).await?; + let read_affinity = conn.read_u64_le().await?; if read_affinity != 0 { - let _cpu_affinity = wire::read_u64(&mut conn).await?; + let _cpu_affinity = conn.read_u64_le().await?; }; } if client_version.minor() >= 11 { // Obsolete reserveSpace - let _reserve_space = wire::read_u64(&mut conn).await?; + let _reserve_space = conn.read_u64_le().await?; } if client_version.minor() >= 33 { // Nix version. We're plain lying, we're not Nix, but eh… @@ -245,7 +245,7 @@ where /// Read a worker [Operation] from the wire. pub async fn read_op(r: &mut R) -> std::io::Result { - let op_number = wire::read_u64(r).await?; + let op_number = r.read_u64_le().await?; Operation::from_u64(op_number).ok_or(Error::new( ErrorKind::InvalidData, format!("Invalid OP number {}", op_number), @@ -278,8 +278,8 @@ where W: AsyncReadExt + AsyncWriteExt + Unpin, { match t { - Trust::Trusted => wire::write_u64(conn, 1).await, - Trust::NotTrusted => wire::write_u64(conn, 2).await, + Trust::Trusted => conn.write_u64_le(1).await, + Trust::NotTrusted => conn.write_u64_le(2).await, } } diff --git a/tvix/nix-compat/src/wire/bytes/mod.rs b/tvix/nix-compat/src/wire/bytes/mod.rs index 031d969e2..fc777bafe 100644 --- a/tvix/nix-compat/src/wire/bytes/mod.rs +++ b/tvix/nix-compat/src/wire/bytes/mod.rs @@ -9,8 +9,6 @@ pub use reader::BytesReader; mod writer; pub use writer::BytesWriter; -use super::primitive; - /// 8 null bytes, used to write out padding. const EMPTY_BYTES: &[u8; 8] = &[0u8; 8]; @@ -41,7 +39,7 @@ where S: RangeBounds, { // read the length field - let len = primitive::read_u64(r).await?; + let len = r.read_u64_le().await?; if !allowed_size.contains(&len) { return Err(std::io::Error::new( @@ -52,7 +50,7 @@ where // calculate the total length, including padding. // byte packets are padded to 8 byte blocks each. - let padded_len = padding_len(len) as u64 + (len as u64); + let padded_len = padding_len(len) as u64 + len; let mut limited_reader = r.take(padded_len); let mut buf = Vec::new(); @@ -105,7 +103,7 @@ pub async fn write_bytes>( b: B, ) -> std::io::Result<()> { // write the size packet. - primitive::write_u64(w, b.as_ref().len() as u64).await?; + w.write_u64_le(b.as_ref().len() as u64).await?; // write the payload w.write_all(b.as_ref()).await?; diff --git a/tvix/nix-compat/src/wire/bytes/reader/mod.rs b/tvix/nix-compat/src/wire/bytes/reader/mod.rs index ef59e9c16..50398d9b9 100644 --- a/tvix/nix-compat/src/wire/bytes/reader/mod.rs +++ b/tvix/nix-compat/src/wire/bytes/reader/mod.rs @@ -5,9 +5,7 @@ use std::{ pin::Pin, task::{self, ready, Poll}, }; -use tokio::io::{AsyncRead, ReadBuf}; - -use crate::wire::read_u64; +use tokio::io::{AsyncRead, AsyncReadExt, ReadBuf}; use trailer::{read_trailer, ReadTrailer, Trailer}; mod trailer; @@ -52,7 +50,7 @@ where { /// Constructs a new BytesReader, using the underlying passed reader. pub async fn new>(mut reader: R, allowed_size: S) -> io::Result { - let size = read_u64(&mut reader).await?; + let size = reader.read_u64_le().await?; if !allowed_size.contains(&size) { return Err(io::Error::new(io::ErrorKind::InvalidData, "invalid size")); diff --git a/tvix/nix-compat/src/wire/mod.rs b/tvix/nix-compat/src/wire/mod.rs index 65c053d58..a197e3a1f 100644 --- a/tvix/nix-compat/src/wire/mod.rs +++ b/tvix/nix-compat/src/wire/mod.rs @@ -3,6 +3,3 @@ mod bytes; pub use bytes::*; - -mod primitive; -pub use primitive::*; diff --git a/tvix/nix-compat/src/wire/primitive.rs b/tvix/nix-compat/src/wire/primitive.rs deleted file mode 100644 index ee0f5fc42..000000000 --- a/tvix/nix-compat/src/wire/primitive.rs +++ /dev/null @@ -1,74 +0,0 @@ -// SPDX-FileCopyrightText: 2023 embr -// -// SPDX-License-Identifier: EUPL-1.2 - -use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; - -#[allow(dead_code)] -/// Read a u64 from the AsyncRead (little endian). -pub async fn read_u64(r: &mut R) -> std::io::Result { - r.read_u64_le().await -} - -/// Write a u64 to the AsyncWrite (little endian). -pub async fn write_u64(w: &mut W, v: u64) -> std::io::Result<()> { - w.write_u64_le(v).await -} - -#[allow(dead_code)] -/// Read a boolean from the AsyncRead, encoded as u64 (>0 is true). -pub async fn read_bool(r: &mut R) -> std::io::Result { - Ok(read_u64(r).await? > 0) -} - -#[allow(dead_code)] -/// Write a boolean to the AsyncWrite, encoded as u64 (>0 is true). -pub async fn write_bool(w: &mut W, v: bool) -> std::io::Result<()> { - write_u64(w, if v { 1u64 } else { 0u64 }).await -} - -#[cfg(test)] -mod tests { - use super::*; - use tokio_test::io::Builder; - - // Integers. - #[tokio::test] - async fn test_read_u64() { - let mut mock = Builder::new().read(&1234567890u64.to_le_bytes()).build(); - assert_eq!(1234567890u64, read_u64(&mut mock).await.unwrap()); - } - #[tokio::test] - async fn test_write_u64() { - let mut mock = Builder::new().write(&1234567890u64.to_le_bytes()).build(); - write_u64(&mut mock, 1234567890).await.unwrap(); - } - - // Booleans. - #[tokio::test] - async fn test_read_bool_0() { - let mut mock = Builder::new().read(&0u64.to_le_bytes()).build(); - assert!(!read_bool(&mut mock).await.unwrap()); - } - #[tokio::test] - async fn test_read_bool_1() { - let mut mock = Builder::new().read(&1u64.to_le_bytes()).build(); - assert!(read_bool(&mut mock).await.unwrap()); - } - #[tokio::test] - async fn test_read_bool_2() { - let mut mock = Builder::new().read(&2u64.to_le_bytes()).build(); - assert!(read_bool(&mut mock).await.unwrap()); - } - - #[tokio::test] - async fn test_write_bool_false() { - let mut mock = Builder::new().write(&0u64.to_le_bytes()).build(); - write_bool(&mut mock, false).await.unwrap(); - } - #[tokio::test] - async fn test_write_bool_true() { - let mut mock = Builder::new().write(&1u64.to_le_bytes()).build(); - write_bool(&mut mock, true).await.unwrap(); - } -} diff --git a/users/picnoir/tvix-daemon/src/main.rs b/users/picnoir/tvix-daemon/src/main.rs index 102067fcf..dc49b209e 100644 --- a/users/picnoir/tvix-daemon/src/main.rs +++ b/users/picnoir/tvix-daemon/src/main.rs @@ -4,7 +4,7 @@ use tokio_listener::{self, SystemOptions, UserOptions}; use tracing::{debug, error, info, instrument, Level}; use nix_compat::worker_protocol::{self, server_handshake_client, ClientSettings, Trust}; -use nix_compat::{wire, ProtocolVersion}; +use nix_compat::ProtocolVersion; #[derive(Parser, Debug)] struct Cli { @@ -78,7 +78,9 @@ where // TODO: implement logging. For now, we'll just send // STDERR_LAST, which is good enough to get Nix respond to // us. - wire::write_u64(&mut client_connection.conn, worker_protocol::STDERR_LAST) + client_connection + .conn + .write_u64_le(worker_protocol::STDERR_LAST) .await .unwrap(); loop { @@ -109,6 +111,6 @@ where let settings = worker_protocol::read_client_settings(&mut conn.conn, conn.version).await?; // The client expects us to send some logs when we're processing // the settings. Sending STDERR_LAST signal we're done processing. - wire::write_u64(&mut conn.conn, worker_protocol::STDERR_LAST).await?; + conn.conn.write_u64_le(worker_protocol::STDERR_LAST).await?; Ok(settings) } -- cgit 1.4.1