[fix](trx-client): harden json tcp parsing and io limits
Add typed remote endpoint parsing (including bracketed IPv6), bounded JSON line reads, and read/write/request timeouts across client/server JSON-TCP paths. Co-authored-by: OpenAI Codex <codex@openai.com> Signed-off-by: Stanislaw Grams <stanislawgrams@gmail.com>
This commit is contained in:
@@ -164,7 +164,7 @@ async fn async_init() -> DynResult<AppState> {
|
|||||||
.or_else(|| cfg.remote.url.clone())
|
.or_else(|| cfg.remote.url.clone())
|
||||||
.ok_or("Remote URL not specified. Use --url or set [remote].url in config.")?;
|
.ok_or("Remote URL not specified. Use --url or set [remote].url in config.")?;
|
||||||
|
|
||||||
let remote_addr =
|
let remote_endpoint =
|
||||||
parse_remote_url(&remote_url).map_err(|e| format!("Invalid remote URL: {}", e))?;
|
parse_remote_url(&remote_url).map_err(|e| format!("Invalid remote URL: {}", e))?;
|
||||||
|
|
||||||
let remote_token = cli.token.clone().or_else(|| cfg.remote.auth.token.clone());
|
let remote_token = cli.token.clone().or_else(|| cfg.remote.auth.token.clone());
|
||||||
@@ -216,7 +216,7 @@ async fn async_init() -> DynResult<AppState> {
|
|||||||
|
|
||||||
info!(
|
info!(
|
||||||
"Starting trx-client (remote: {}, frontends: {})",
|
"Starting trx-client (remote: {}, frontends: {})",
|
||||||
remote_addr,
|
remote_endpoint.connect_addr(),
|
||||||
frontends.join(", ")
|
frontends.join(", ")
|
||||||
);
|
);
|
||||||
|
|
||||||
@@ -228,14 +228,10 @@ async fn async_init() -> DynResult<AppState> {
|
|||||||
let (state_tx, state_rx) = watch::channel(initial_state);
|
let (state_tx, state_rx) = watch::channel(initial_state);
|
||||||
|
|
||||||
// Extract host for audio before moving remote_addr
|
// Extract host for audio before moving remote_addr
|
||||||
let remote_host = remote_addr
|
let remote_host = remote_endpoint.host.clone();
|
||||||
.split(':')
|
|
||||||
.next()
|
|
||||||
.unwrap_or("127.0.0.1")
|
|
||||||
.to_string();
|
|
||||||
|
|
||||||
let remote_cfg = RemoteClientConfig {
|
let remote_cfg = RemoteClientConfig {
|
||||||
addr: remote_addr,
|
addr: remote_endpoint.connect_addr(),
|
||||||
token: remote_token,
|
token: remote_token,
|
||||||
poll_interval: Duration::from_millis(poll_interval_ms),
|
poll_interval: Duration::from_millis(poll_interval_ms),
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -4,7 +4,7 @@
|
|||||||
|
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
|
||||||
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
|
use tokio::io::{AsyncBufRead, AsyncBufReadExt, AsyncWriteExt, BufReader};
|
||||||
use tokio::net::TcpStream;
|
use tokio::net::TcpStream;
|
||||||
use tokio::sync::{mpsc, watch};
|
use tokio::sync::{mpsc, watch};
|
||||||
use tokio::time::{self, Instant};
|
use tokio::time::{self, Instant};
|
||||||
@@ -16,6 +16,27 @@ use trx_core::{RigError, RigResult};
|
|||||||
use trx_protocol::rig_command_to_client;
|
use trx_protocol::rig_command_to_client;
|
||||||
use trx_protocol::{ClientCommand, ClientEnvelope, ClientResponse};
|
use trx_protocol::{ClientCommand, ClientEnvelope, ClientResponse};
|
||||||
|
|
||||||
|
const DEFAULT_REMOTE_PORT: u16 = 4532;
|
||||||
|
const CONNECT_TIMEOUT: Duration = Duration::from_secs(5);
|
||||||
|
const IO_TIMEOUT: Duration = Duration::from_secs(10);
|
||||||
|
const MAX_JSON_LINE_BYTES: usize = 16 * 1024;
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, PartialEq, Eq)]
|
||||||
|
pub struct RemoteEndpoint {
|
||||||
|
pub host: String,
|
||||||
|
pub port: u16,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RemoteEndpoint {
|
||||||
|
pub fn connect_addr(&self) -> String {
|
||||||
|
if self.host.contains(':') && !self.host.starts_with('[') {
|
||||||
|
format!("[{}]:{}", self.host, self.port)
|
||||||
|
} else {
|
||||||
|
format!("{}:{}", self.host, self.port)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub struct RemoteClientConfig {
|
pub struct RemoteClientConfig {
|
||||||
pub addr: String,
|
pub addr: String,
|
||||||
pub token: Option<String>,
|
pub token: Option<String>,
|
||||||
@@ -37,17 +58,20 @@ pub async fn run_remote_client(
|
|||||||
}
|
}
|
||||||
|
|
||||||
info!("Remote client: connecting to {}", config.addr);
|
info!("Remote client: connecting to {}", config.addr);
|
||||||
match TcpStream::connect(&config.addr).await {
|
match time::timeout(CONNECT_TIMEOUT, TcpStream::connect(&config.addr)).await {
|
||||||
Ok(stream) => {
|
Ok(Ok(stream)) => {
|
||||||
if let Err(e) =
|
if let Err(e) =
|
||||||
handle_connection(&config, stream, &mut rx, &state_tx, &mut shutdown_rx).await
|
handle_connection(&config, stream, &mut rx, &state_tx, &mut shutdown_rx).await
|
||||||
{
|
{
|
||||||
warn!("Remote connection dropped: {}", e);
|
warn!("Remote connection dropped: {}", e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Ok(Err(e)) => {
|
||||||
warn!("Remote connect failed: {}", e);
|
warn!("Remote connect failed: {}", e);
|
||||||
}
|
}
|
||||||
|
Err(_) => {
|
||||||
|
warn!("Remote connect timed out after {:?}", CONNECT_TIMEOUT);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
tokio::select! {
|
tokio::select! {
|
||||||
@@ -128,20 +152,23 @@ async fn send_command(
|
|||||||
let payload = serde_json::to_string(&envelope)
|
let payload = serde_json::to_string(&envelope)
|
||||||
.map_err(|e| RigError::communication(format!("JSON serialize failed: {e}")))?;
|
.map_err(|e| RigError::communication(format!("JSON serialize failed: {e}")))?;
|
||||||
|
|
||||||
writer
|
time::timeout(
|
||||||
.write_all(format!("{}\n", payload).as_bytes())
|
IO_TIMEOUT,
|
||||||
.await
|
writer.write_all(format!("{}\n", payload).as_bytes()),
|
||||||
.map_err(|e| RigError::communication(format!("write failed: {e}")))?;
|
)
|
||||||
writer
|
.await
|
||||||
.flush()
|
.map_err(|_| RigError::communication(format!("write timed out after {:?}", IO_TIMEOUT)))?
|
||||||
|
.map_err(|e| RigError::communication(format!("write failed: {e}")))?;
|
||||||
|
time::timeout(IO_TIMEOUT, writer.flush())
|
||||||
.await
|
.await
|
||||||
|
.map_err(|_| RigError::communication(format!("flush timed out after {:?}", IO_TIMEOUT)))?
|
||||||
.map_err(|e| RigError::communication(format!("flush failed: {e}")))?;
|
.map_err(|e| RigError::communication(format!("flush failed: {e}")))?;
|
||||||
|
|
||||||
let mut line = String::new();
|
let line = time::timeout(IO_TIMEOUT, read_limited_line(reader, MAX_JSON_LINE_BYTES))
|
||||||
reader
|
|
||||||
.read_line(&mut line)
|
|
||||||
.await
|
.await
|
||||||
|
.map_err(|_| RigError::communication(format!("read timed out after {:?}", IO_TIMEOUT)))?
|
||||||
.map_err(|e| RigError::communication(format!("read failed: {e}")))?;
|
.map_err(|e| RigError::communication(format!("read failed: {e}")))?;
|
||||||
|
let line = line.ok_or_else(|| RigError::communication("connection closed by remote"))?;
|
||||||
|
|
||||||
let resp: ClientResponse = serde_json::from_str(line.trim_end())
|
let resp: ClientResponse = serde_json::from_str(line.trim_end())
|
||||||
.map_err(|e| RigError::communication(format!("invalid response: {e}")))?;
|
.map_err(|e| RigError::communication(format!("invalid response: {e}")))?;
|
||||||
@@ -159,7 +186,59 @@ async fn send_command(
|
|||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn parse_remote_url(url: &str) -> Result<String, String> {
|
async fn read_limited_line<R: AsyncBufRead + Unpin>(
|
||||||
|
reader: &mut R,
|
||||||
|
max_bytes: usize,
|
||||||
|
) -> std::io::Result<Option<String>> {
|
||||||
|
let mut line = Vec::with_capacity(256);
|
||||||
|
loop {
|
||||||
|
let available = reader.fill_buf().await?;
|
||||||
|
if available.is_empty() {
|
||||||
|
if line.is_empty() {
|
||||||
|
return Ok(None);
|
||||||
|
}
|
||||||
|
let text = String::from_utf8(line).map_err(|e| {
|
||||||
|
std::io::Error::new(
|
||||||
|
std::io::ErrorKind::InvalidData,
|
||||||
|
format!("line is not valid UTF-8: {e}"),
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
return Ok(Some(text));
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(pos) = available.iter().position(|b| *b == b'\n') {
|
||||||
|
let chunk = &available[..=pos];
|
||||||
|
if line.len() + chunk.len() > max_bytes {
|
||||||
|
return Err(std::io::Error::new(
|
||||||
|
std::io::ErrorKind::InvalidData,
|
||||||
|
format!("line exceeds maximum size of {max_bytes} bytes"),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
line.extend_from_slice(chunk);
|
||||||
|
reader.consume(pos + 1);
|
||||||
|
let text = String::from_utf8(line).map_err(|e| {
|
||||||
|
std::io::Error::new(
|
||||||
|
std::io::ErrorKind::InvalidData,
|
||||||
|
format!("line is not valid UTF-8: {e}"),
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
return Ok(Some(text));
|
||||||
|
}
|
||||||
|
|
||||||
|
if line.len() + available.len() > max_bytes {
|
||||||
|
return Err(std::io::Error::new(
|
||||||
|
std::io::ErrorKind::InvalidData,
|
||||||
|
format!("line exceeds maximum size of {max_bytes} bytes"),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
line.extend_from_slice(available);
|
||||||
|
let consumed = available.len();
|
||||||
|
reader.consume(consumed);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn parse_remote_url(url: &str) -> Result<RemoteEndpoint, String> {
|
||||||
let trimmed = url.trim();
|
let trimmed = url.trim();
|
||||||
if trimmed.is_empty() {
|
if trimmed.is_empty() {
|
||||||
return Err("remote url is empty".into());
|
return Err("remote url is empty".into());
|
||||||
@@ -170,9 +249,107 @@ pub fn parse_remote_url(url: &str) -> Result<String, String> {
|
|||||||
.or_else(|| trimmed.strip_prefix("http-json://"))
|
.or_else(|| trimmed.strip_prefix("http-json://"))
|
||||||
.unwrap_or(trimmed);
|
.unwrap_or(trimmed);
|
||||||
|
|
||||||
if !addr.contains(':') {
|
parse_host_port(addr)
|
||||||
return Ok(format!("{}:4532", addr));
|
}
|
||||||
|
|
||||||
|
fn parse_host_port(input: &str) -> Result<RemoteEndpoint, String> {
|
||||||
|
if let Some(rest) = input.strip_prefix('[') {
|
||||||
|
let closing = rest
|
||||||
|
.find(']')
|
||||||
|
.ok_or("invalid remote url: missing closing ']' for IPv6 host")?;
|
||||||
|
let host = &rest[..closing];
|
||||||
|
let remainder = &rest[closing + 1..];
|
||||||
|
if host.is_empty() {
|
||||||
|
return Err("invalid remote url: host is empty".into());
|
||||||
|
}
|
||||||
|
let port = if remainder.is_empty() {
|
||||||
|
DEFAULT_REMOTE_PORT
|
||||||
|
} else if let Some(port_str) = remainder.strip_prefix(':') {
|
||||||
|
parse_port(port_str)?
|
||||||
|
} else {
|
||||||
|
return Err("invalid remote url: expected ':<port>' after ']'".into());
|
||||||
|
};
|
||||||
|
return Ok(RemoteEndpoint {
|
||||||
|
host: host.to_string(),
|
||||||
|
port,
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(addr.to_string())
|
if input.contains(':') {
|
||||||
|
if input.matches(':').count() > 1 {
|
||||||
|
return Err("invalid remote url: IPv6 host must be bracketed like [::1]:4532".into());
|
||||||
|
}
|
||||||
|
let (host, port_str) = input
|
||||||
|
.rsplit_once(':')
|
||||||
|
.ok_or("invalid remote url: expected host:port")?;
|
||||||
|
if host.is_empty() {
|
||||||
|
return Err("invalid remote url: host is empty".into());
|
||||||
|
}
|
||||||
|
return Ok(RemoteEndpoint {
|
||||||
|
host: host.to_string(),
|
||||||
|
port: parse_port(port_str)?,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(RemoteEndpoint {
|
||||||
|
host: input.to_string(),
|
||||||
|
port: DEFAULT_REMOTE_PORT,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_port(port_str: &str) -> Result<u16, String> {
|
||||||
|
let port: u16 = port_str
|
||||||
|
.parse()
|
||||||
|
.map_err(|_| format!("invalid remote port: '{port_str}'"))?;
|
||||||
|
if port == 0 {
|
||||||
|
return Err("invalid remote port: 0".into());
|
||||||
|
}
|
||||||
|
Ok(port)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::{parse_remote_url, RemoteEndpoint};
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_host_default_port() {
|
||||||
|
let parsed = parse_remote_url("example.local").expect("must parse");
|
||||||
|
assert_eq!(
|
||||||
|
parsed,
|
||||||
|
RemoteEndpoint {
|
||||||
|
host: "example.local".to_string(),
|
||||||
|
port: 4532
|
||||||
|
}
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_ipv4_with_port() {
|
||||||
|
let parsed = parse_remote_url("tcp://127.0.0.1:9000").expect("must parse");
|
||||||
|
assert_eq!(
|
||||||
|
parsed,
|
||||||
|
RemoteEndpoint {
|
||||||
|
host: "127.0.0.1".to_string(),
|
||||||
|
port: 9000
|
||||||
|
}
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_bracketed_ipv6() {
|
||||||
|
let parsed = parse_remote_url("http-json://[::1]:7000").expect("must parse");
|
||||||
|
assert_eq!(
|
||||||
|
parsed,
|
||||||
|
RemoteEndpoint {
|
||||||
|
host: "::1".to_string(),
|
||||||
|
port: 7000
|
||||||
|
}
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn reject_unbracketed_ipv6() {
|
||||||
|
let err = parse_remote_url("::1:7000").expect_err("must fail");
|
||||||
|
assert!(err.contains("must be bracketed"));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,21 +4,27 @@
|
|||||||
|
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
|
use tokio::io::{AsyncBufRead, AsyncBufReadExt, AsyncWriteExt, BufReader};
|
||||||
use tokio::net::{TcpListener, TcpStream};
|
use tokio::net::{TcpListener, TcpStream};
|
||||||
use tokio::sync::{mpsc, oneshot, watch};
|
use tokio::sync::{mpsc, oneshot, watch};
|
||||||
use tokio::task::JoinHandle;
|
use tokio::task::JoinHandle;
|
||||||
|
use tokio::time;
|
||||||
use tracing::{error, info};
|
use tracing::{error, info};
|
||||||
|
|
||||||
use trx_core::rig::request::RigRequest;
|
use trx_core::rig::request::RigRequest;
|
||||||
use trx_core::rig::state::RigState;
|
use trx_core::rig::state::RigState;
|
||||||
use trx_frontend::{FrontendSpawner, FrontendRuntimeContext};
|
use trx_frontend::{FrontendRuntimeContext, FrontendSpawner};
|
||||||
use trx_protocol::auth::{SimpleTokenValidator, TokenValidator};
|
use trx_protocol::auth::{SimpleTokenValidator, TokenValidator};
|
||||||
use trx_protocol::codec::parse_envelope;
|
use trx_protocol::codec::parse_envelope;
|
||||||
use trx_protocol::mapping;
|
use trx_protocol::mapping;
|
||||||
use trx_protocol::ClientResponse;
|
use trx_protocol::ClientResponse;
|
||||||
|
|
||||||
|
const IO_TIMEOUT: Duration = Duration::from_secs(10);
|
||||||
|
const REQUEST_TIMEOUT: Duration = Duration::from_secs(12);
|
||||||
|
const MAX_JSON_LINE_BYTES: usize = 16 * 1024;
|
||||||
|
|
||||||
/// JSON-over-TCP frontend for control and status.
|
/// JSON-over-TCP frontend for control and status.
|
||||||
pub struct HttpJsonFrontend;
|
pub struct HttpJsonFrontend;
|
||||||
|
|
||||||
@@ -68,17 +74,24 @@ async fn handle_client(
|
|||||||
) -> std::io::Result<()> {
|
) -> std::io::Result<()> {
|
||||||
let (reader, mut writer) = socket.into_split();
|
let (reader, mut writer) = socket.into_split();
|
||||||
let mut reader = BufReader::new(reader);
|
let mut reader = BufReader::new(reader);
|
||||||
let mut line = String::new();
|
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
line.clear();
|
let line = time::timeout(
|
||||||
let bytes_read = reader.read_line(&mut line).await?;
|
IO_TIMEOUT,
|
||||||
if bytes_read == 0 {
|
read_limited_line(&mut reader, MAX_JSON_LINE_BYTES),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.map_err(|_| {
|
||||||
|
std::io::Error::new(
|
||||||
|
std::io::ErrorKind::TimedOut,
|
||||||
|
"read timeout waiting for client request",
|
||||||
|
)
|
||||||
|
})??;
|
||||||
|
let Some(line) = line else {
|
||||||
info!("json tcp client {} disconnected", addr);
|
info!("json tcp client {} disconnected", addr);
|
||||||
break;
|
break;
|
||||||
}
|
};
|
||||||
|
|
||||||
// Simple protocol: one line = one JSON command.
|
|
||||||
let trimmed = line.trim();
|
let trimmed = line.trim();
|
||||||
if trimmed.is_empty() {
|
if trimmed.is_empty() {
|
||||||
continue;
|
continue;
|
||||||
@@ -93,9 +106,7 @@ async fn handle_client(
|
|||||||
state: None,
|
state: None,
|
||||||
error: Some(format!("Invalid JSON: {}", e)),
|
error: Some(format!("Invalid JSON: {}", e)),
|
||||||
};
|
};
|
||||||
let resp_line = serde_json::to_string(&resp)? + "\n";
|
send_response(&mut writer, &resp).await?;
|
||||||
writer.write_all(resp_line.as_bytes()).await?;
|
|
||||||
writer.flush().await?;
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@@ -106,13 +117,10 @@ async fn handle_client(
|
|||||||
state: None,
|
state: None,
|
||||||
error: Some(err),
|
error: Some(err),
|
||||||
};
|
};
|
||||||
let resp_line = serde_json::to_string(&resp)? + "\n";
|
send_response(&mut writer, &resp).await?;
|
||||||
writer.write_all(resp_line.as_bytes()).await?;
|
|
||||||
writer.flush().await?;
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Map ClientCommand -> RigCommand using trx-protocol.
|
|
||||||
let rig_cmd = mapping::client_command_to_rig(envelope.cmd);
|
let rig_cmd = mapping::client_command_to_rig(envelope.cmd);
|
||||||
|
|
||||||
let (resp_tx, resp_rx) = oneshot::channel();
|
let (resp_tx, resp_rx) = oneshot::channel();
|
||||||
@@ -121,50 +129,62 @@ async fn handle_client(
|
|||||||
respond_to: resp_tx,
|
respond_to: resp_tx,
|
||||||
};
|
};
|
||||||
|
|
||||||
if let Err(e) = tx.send(req).await {
|
match time::timeout(IO_TIMEOUT, tx.send(req)).await {
|
||||||
error!("Failed to send request to rig_task: {:?}", e);
|
Ok(Ok(())) => {}
|
||||||
let resp = ClientResponse {
|
Ok(Err(e)) => {
|
||||||
success: false,
|
error!("Failed to send request to rig_task: {:?}", e);
|
||||||
state: None,
|
let resp = ClientResponse {
|
||||||
error: Some("Internal error: rig task not available".into()),
|
success: false,
|
||||||
};
|
state: None,
|
||||||
let resp_line = serde_json::to_string(&resp)? + "\n";
|
error: Some("Internal error: rig task not available".into()),
|
||||||
writer.write_all(resp_line.as_bytes()).await?;
|
};
|
||||||
writer.flush().await?;
|
send_response(&mut writer, &resp).await?;
|
||||||
continue;
|
continue;
|
||||||
|
}
|
||||||
|
Err(_) => {
|
||||||
|
let resp = ClientResponse {
|
||||||
|
success: false,
|
||||||
|
state: None,
|
||||||
|
error: Some("Internal error: request queue timeout".into()),
|
||||||
|
};
|
||||||
|
send_response(&mut writer, &resp).await?;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
match resp_rx.await {
|
match time::timeout(REQUEST_TIMEOUT, resp_rx).await {
|
||||||
Ok(Ok(snapshot)) => {
|
Ok(Ok(Ok(snapshot))) => {
|
||||||
let resp = ClientResponse {
|
let resp = ClientResponse {
|
||||||
success: true,
|
success: true,
|
||||||
state: Some(snapshot),
|
state: Some(snapshot),
|
||||||
error: None,
|
error: None,
|
||||||
};
|
};
|
||||||
let resp_line = serde_json::to_string(&resp)? + "\n";
|
send_response(&mut writer, &resp).await?;
|
||||||
writer.write_all(resp_line.as_bytes()).await?;
|
|
||||||
writer.flush().await?;
|
|
||||||
}
|
}
|
||||||
Ok(Err(err)) => {
|
Ok(Ok(Err(err))) => {
|
||||||
let resp = ClientResponse {
|
let resp = ClientResponse {
|
||||||
success: false,
|
success: false,
|
||||||
state: None,
|
state: None,
|
||||||
error: Some(err.message),
|
error: Some(err.message),
|
||||||
};
|
};
|
||||||
let resp_line = serde_json::to_string(&resp)? + "\n";
|
send_response(&mut writer, &resp).await?;
|
||||||
writer.write_all(resp_line.as_bytes()).await?;
|
|
||||||
writer.flush().await?;
|
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Ok(Err(e)) => {
|
||||||
error!("Rig response oneshot recv error: {:?}", e);
|
error!("Rig response oneshot recv error: {:?}", e);
|
||||||
let resp = ClientResponse {
|
let resp = ClientResponse {
|
||||||
success: false,
|
success: false,
|
||||||
state: None,
|
state: None,
|
||||||
error: Some("Internal error waiting for rig response".into()),
|
error: Some("Internal error waiting for rig response".into()),
|
||||||
};
|
};
|
||||||
let resp_line = serde_json::to_string(&resp)? + "\n";
|
send_response(&mut writer, &resp).await?;
|
||||||
writer.write_all(resp_line.as_bytes()).await?;
|
}
|
||||||
writer.flush().await?;
|
Err(_) => {
|
||||||
|
let resp = ClientResponse {
|
||||||
|
success: false,
|
||||||
|
state: None,
|
||||||
|
error: Some("Request timed out waiting for rig response".into()),
|
||||||
|
};
|
||||||
|
send_response(&mut writer, &resp).await?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -172,6 +192,76 @@ async fn handle_client(
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn read_limited_line<R: AsyncBufRead + Unpin>(
|
||||||
|
reader: &mut R,
|
||||||
|
max_bytes: usize,
|
||||||
|
) -> std::io::Result<Option<String>> {
|
||||||
|
let mut line = Vec::with_capacity(256);
|
||||||
|
loop {
|
||||||
|
let available = reader.fill_buf().await?;
|
||||||
|
if available.is_empty() {
|
||||||
|
if line.is_empty() {
|
||||||
|
return Ok(None);
|
||||||
|
}
|
||||||
|
let text = String::from_utf8(line).map_err(|e| {
|
||||||
|
std::io::Error::new(
|
||||||
|
std::io::ErrorKind::InvalidData,
|
||||||
|
format!("line is not valid UTF-8: {e}"),
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
return Ok(Some(text));
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(pos) = available.iter().position(|b| *b == b'\n') {
|
||||||
|
let chunk = &available[..=pos];
|
||||||
|
if line.len() + chunk.len() > max_bytes {
|
||||||
|
return Err(std::io::Error::new(
|
||||||
|
std::io::ErrorKind::InvalidData,
|
||||||
|
format!("line exceeds maximum size of {max_bytes} bytes"),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
line.extend_from_slice(chunk);
|
||||||
|
reader.consume(pos + 1);
|
||||||
|
let text = String::from_utf8(line).map_err(|e| {
|
||||||
|
std::io::Error::new(
|
||||||
|
std::io::ErrorKind::InvalidData,
|
||||||
|
format!("line is not valid UTF-8: {e}"),
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
return Ok(Some(text));
|
||||||
|
}
|
||||||
|
|
||||||
|
if line.len() + available.len() > max_bytes {
|
||||||
|
return Err(std::io::Error::new(
|
||||||
|
std::io::ErrorKind::InvalidData,
|
||||||
|
format!("line exceeds maximum size of {max_bytes} bytes"),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
line.extend_from_slice(available);
|
||||||
|
let consumed = available.len();
|
||||||
|
reader.consume(consumed);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn send_response(
|
||||||
|
writer: &mut tokio::net::tcp::OwnedWriteHalf,
|
||||||
|
response: &ClientResponse,
|
||||||
|
) -> std::io::Result<()> {
|
||||||
|
let resp_line = serde_json::to_string(response).map_err(std::io::Error::other)? + "\n";
|
||||||
|
time::timeout(IO_TIMEOUT, writer.write_all(resp_line.as_bytes()))
|
||||||
|
.await
|
||||||
|
.map_err(|_| {
|
||||||
|
std::io::Error::new(std::io::ErrorKind::TimedOut, "response write timeout")
|
||||||
|
})??;
|
||||||
|
time::timeout(IO_TIMEOUT, writer.flush())
|
||||||
|
.await
|
||||||
|
.map_err(|_| {
|
||||||
|
std::io::Error::new(std::io::ErrorKind::TimedOut, "response flush timeout")
|
||||||
|
})??;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
fn authorize(token: &Option<String>, context: &FrontendRuntimeContext) -> Result<(), String> {
|
fn authorize(token: &Option<String>, context: &FrontendRuntimeContext) -> Result<(), String> {
|
||||||
let validator = SimpleTokenValidator::new(context.auth_tokens.clone());
|
let validator = SimpleTokenValidator::new(context.auth_tokens.clone());
|
||||||
validator.validate(token)
|
validator.validate(token)
|
||||||
|
|||||||
+133
-37
@@ -10,10 +10,12 @@
|
|||||||
use std::collections::HashSet;
|
use std::collections::HashSet;
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
|
use tokio::io::{AsyncBufRead, AsyncBufReadExt, AsyncWriteExt, BufReader};
|
||||||
use tokio::net::{TcpListener, TcpStream};
|
use tokio::net::{TcpListener, TcpStream};
|
||||||
use tokio::sync::{mpsc, oneshot, watch};
|
use tokio::sync::{mpsc, oneshot, watch};
|
||||||
|
use tokio::time;
|
||||||
use tracing::{error, info};
|
use tracing::{error, info};
|
||||||
|
|
||||||
use trx_core::rig::command::RigCommand;
|
use trx_core::rig::command::RigCommand;
|
||||||
@@ -24,6 +26,10 @@ use trx_protocol::codec::parse_envelope;
|
|||||||
use trx_protocol::mapping;
|
use trx_protocol::mapping;
|
||||||
use trx_protocol::ClientResponse;
|
use trx_protocol::ClientResponse;
|
||||||
|
|
||||||
|
const IO_TIMEOUT: Duration = Duration::from_secs(10);
|
||||||
|
const REQUEST_TIMEOUT: Duration = Duration::from_secs(12);
|
||||||
|
const MAX_JSON_LINE_BYTES: usize = 16 * 1024;
|
||||||
|
|
||||||
/// Run the JSON TCP listener, accepting client connections.
|
/// Run the JSON TCP listener, accepting client connections.
|
||||||
pub async fn run_listener(
|
pub async fn run_listener(
|
||||||
addr: SocketAddr,
|
addr: SocketAddr,
|
||||||
@@ -68,6 +74,76 @@ pub async fn run_listener(
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn read_limited_line<R: AsyncBufRead + Unpin>(
|
||||||
|
reader: &mut R,
|
||||||
|
max_bytes: usize,
|
||||||
|
) -> std::io::Result<Option<String>> {
|
||||||
|
let mut line = Vec::with_capacity(256);
|
||||||
|
loop {
|
||||||
|
let available = reader.fill_buf().await?;
|
||||||
|
if available.is_empty() {
|
||||||
|
if line.is_empty() {
|
||||||
|
return Ok(None);
|
||||||
|
}
|
||||||
|
let text = String::from_utf8(line).map_err(|e| {
|
||||||
|
std::io::Error::new(
|
||||||
|
std::io::ErrorKind::InvalidData,
|
||||||
|
format!("line is not valid UTF-8: {e}"),
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
return Ok(Some(text));
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(pos) = available.iter().position(|b| *b == b'\n') {
|
||||||
|
let chunk = &available[..=pos];
|
||||||
|
if line.len() + chunk.len() > max_bytes {
|
||||||
|
return Err(std::io::Error::new(
|
||||||
|
std::io::ErrorKind::InvalidData,
|
||||||
|
format!("line exceeds maximum size of {max_bytes} bytes"),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
line.extend_from_slice(chunk);
|
||||||
|
reader.consume(pos + 1);
|
||||||
|
let text = String::from_utf8(line).map_err(|e| {
|
||||||
|
std::io::Error::new(
|
||||||
|
std::io::ErrorKind::InvalidData,
|
||||||
|
format!("line is not valid UTF-8: {e}"),
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
return Ok(Some(text));
|
||||||
|
}
|
||||||
|
|
||||||
|
if line.len() + available.len() > max_bytes {
|
||||||
|
return Err(std::io::Error::new(
|
||||||
|
std::io::ErrorKind::InvalidData,
|
||||||
|
format!("line exceeds maximum size of {max_bytes} bytes"),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
line.extend_from_slice(available);
|
||||||
|
let consumed = available.len();
|
||||||
|
reader.consume(consumed);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn send_response(
|
||||||
|
writer: &mut tokio::net::tcp::OwnedWriteHalf,
|
||||||
|
response: &ClientResponse,
|
||||||
|
) -> std::io::Result<()> {
|
||||||
|
let resp_line = serde_json::to_string(response).map_err(std::io::Error::other)? + "\n";
|
||||||
|
time::timeout(IO_TIMEOUT, writer.write_all(resp_line.as_bytes()))
|
||||||
|
.await
|
||||||
|
.map_err(|_| {
|
||||||
|
std::io::Error::new(std::io::ErrorKind::TimedOut, "response write timeout")
|
||||||
|
})??;
|
||||||
|
time::timeout(IO_TIMEOUT, writer.flush())
|
||||||
|
.await
|
||||||
|
.map_err(|_| {
|
||||||
|
std::io::Error::new(std::io::ErrorKind::TimedOut, "response flush timeout")
|
||||||
|
})??;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
async fn handle_client(
|
async fn handle_client(
|
||||||
socket: TcpStream,
|
socket: TcpStream,
|
||||||
addr: SocketAddr,
|
addr: SocketAddr,
|
||||||
@@ -78,12 +154,21 @@ async fn handle_client(
|
|||||||
) -> std::io::Result<()> {
|
) -> std::io::Result<()> {
|
||||||
let (reader, mut writer) = socket.into_split();
|
let (reader, mut writer) = socket.into_split();
|
||||||
let mut reader = BufReader::new(reader);
|
let mut reader = BufReader::new(reader);
|
||||||
let mut line = String::new();
|
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
line.clear();
|
let line = tokio::select! {
|
||||||
let bytes_read = tokio::select! {
|
read = time::timeout(IO_TIMEOUT, read_limited_line(&mut reader, MAX_JSON_LINE_BYTES)) => {
|
||||||
read = reader.read_line(&mut line) => read?,
|
match read {
|
||||||
|
Ok(Ok(line)) => line,
|
||||||
|
Ok(Err(e)) => return Err(e),
|
||||||
|
Err(_) => {
|
||||||
|
return Err(std::io::Error::new(
|
||||||
|
std::io::ErrorKind::TimedOut,
|
||||||
|
"read timeout waiting for client request",
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
changed = shutdown_rx.changed() => {
|
changed = shutdown_rx.changed() => {
|
||||||
match changed {
|
match changed {
|
||||||
Ok(()) if *shutdown_rx.borrow() => {
|
Ok(()) if *shutdown_rx.borrow() => {
|
||||||
@@ -95,10 +180,10 @@ async fn handle_client(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
if bytes_read == 0 {
|
let Some(line) = line else {
|
||||||
info!("Client {} disconnected", addr);
|
info!("Client {} disconnected", addr);
|
||||||
break;
|
break;
|
||||||
}
|
};
|
||||||
|
|
||||||
let trimmed = line.trim();
|
let trimmed = line.trim();
|
||||||
if trimmed.is_empty() {
|
if trimmed.is_empty() {
|
||||||
@@ -114,9 +199,7 @@ async fn handle_client(
|
|||||||
state: None,
|
state: None,
|
||||||
error: Some(format!("Invalid JSON: {}", e)),
|
error: Some(format!("Invalid JSON: {}", e)),
|
||||||
};
|
};
|
||||||
let resp_line = serde_json::to_string(&resp)? + "\n";
|
send_response(&mut writer, &resp).await?;
|
||||||
writer.write_all(resp_line.as_bytes()).await?;
|
|
||||||
writer.flush().await?;
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@@ -127,9 +210,7 @@ async fn handle_client(
|
|||||||
state: None,
|
state: None,
|
||||||
error: Some(err),
|
error: Some(err),
|
||||||
};
|
};
|
||||||
let resp_line = serde_json::to_string(&resp)? + "\n";
|
send_response(&mut writer, &resp).await?;
|
||||||
writer.write_all(resp_line.as_bytes()).await?;
|
|
||||||
writer.flush().await?;
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -145,9 +226,7 @@ async fn handle_client(
|
|||||||
state: Some(snapshot),
|
state: Some(snapshot),
|
||||||
error: None,
|
error: None,
|
||||||
};
|
};
|
||||||
let resp_line = serde_json::to_string(&resp)? + "\n";
|
send_response(&mut writer, &resp).await?;
|
||||||
writer.write_all(resp_line.as_bytes()).await?;
|
|
||||||
writer.flush().await?;
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -158,21 +237,44 @@ async fn handle_client(
|
|||||||
respond_to: resp_tx,
|
respond_to: resp_tx,
|
||||||
};
|
};
|
||||||
|
|
||||||
if let Err(e) = tx.send(req).await {
|
match time::timeout(IO_TIMEOUT, tx.send(req)).await {
|
||||||
error!("Failed to send request to rig_task: {:?}", e);
|
Ok(Ok(())) => {}
|
||||||
let resp = ClientResponse {
|
Ok(Err(e)) => {
|
||||||
success: false,
|
error!("Failed to send request to rig_task: {:?}", e);
|
||||||
state: None,
|
let resp = ClientResponse {
|
||||||
error: Some("Internal error: rig task not available".into()),
|
success: false,
|
||||||
};
|
state: None,
|
||||||
let resp_line = serde_json::to_string(&resp)? + "\n";
|
error: Some("Internal error: rig task not available".into()),
|
||||||
writer.write_all(resp_line.as_bytes()).await?;
|
};
|
||||||
writer.flush().await?;
|
send_response(&mut writer, &resp).await?;
|
||||||
continue;
|
continue;
|
||||||
|
}
|
||||||
|
Err(_) => {
|
||||||
|
let resp = ClientResponse {
|
||||||
|
success: false,
|
||||||
|
state: None,
|
||||||
|
error: Some("Internal error: request queue timeout".into()),
|
||||||
|
};
|
||||||
|
send_response(&mut writer, &resp).await?;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
match tokio::select! {
|
match tokio::select! {
|
||||||
result = resp_rx => result,
|
result = time::timeout(REQUEST_TIMEOUT, resp_rx) => {
|
||||||
|
match result {
|
||||||
|
Ok(inner) => inner,
|
||||||
|
Err(_) => {
|
||||||
|
let resp = ClientResponse {
|
||||||
|
success: false,
|
||||||
|
state: None,
|
||||||
|
error: Some("Request timed out waiting for rig response".into()),
|
||||||
|
};
|
||||||
|
send_response(&mut writer, &resp).await?;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
changed = shutdown_rx.changed() => {
|
changed = shutdown_rx.changed() => {
|
||||||
match changed {
|
match changed {
|
||||||
Ok(()) if *shutdown_rx.borrow() => {
|
Ok(()) if *shutdown_rx.borrow() => {
|
||||||
@@ -190,9 +292,7 @@ async fn handle_client(
|
|||||||
state: Some(snapshot),
|
state: Some(snapshot),
|
||||||
error: None,
|
error: None,
|
||||||
};
|
};
|
||||||
let resp_line = serde_json::to_string(&resp)? + "\n";
|
send_response(&mut writer, &resp).await?;
|
||||||
writer.write_all(resp_line.as_bytes()).await?;
|
|
||||||
writer.flush().await?;
|
|
||||||
}
|
}
|
||||||
Ok(Err(err)) => {
|
Ok(Err(err)) => {
|
||||||
let resp = ClientResponse {
|
let resp = ClientResponse {
|
||||||
@@ -200,9 +300,7 @@ async fn handle_client(
|
|||||||
state: None,
|
state: None,
|
||||||
error: Some(err.message),
|
error: Some(err.message),
|
||||||
};
|
};
|
||||||
let resp_line = serde_json::to_string(&resp)? + "\n";
|
send_response(&mut writer, &resp).await?;
|
||||||
writer.write_all(resp_line.as_bytes()).await?;
|
|
||||||
writer.flush().await?;
|
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
error!("Rig response oneshot recv error: {:?}", e);
|
error!("Rig response oneshot recv error: {:?}", e);
|
||||||
@@ -211,9 +309,7 @@ async fn handle_client(
|
|||||||
state: None,
|
state: None,
|
||||||
error: Some("Internal error waiting for rig response".into()),
|
error: Some("Internal error waiting for rig response".into()),
|
||||||
};
|
};
|
||||||
let resp_line = serde_json::to_string(&resp)? + "\n";
|
send_response(&mut writer, &resp).await?;
|
||||||
writer.write_all(resp_line.as_bytes()).await?;
|
|
||||||
writer.flush().await?;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user