diff --git a/src/trx-server/Cargo.toml b/src/trx-server/Cargo.toml index 5985233..bd0dfb7 100644 --- a/src/trx-server/Cargo.toml +++ b/src/trx-server/Cargo.toml @@ -24,3 +24,4 @@ opus = "0.3" trx-backend = { path = "trx-backend" } trx-core = { path = "../trx-core" } trx-ft8 = { path = "../trx-ft8" } +trx-protocol = { path = "../trx-protocol" } diff --git a/src/trx-server/src/listener.rs b/src/trx-server/src/listener.rs index 3fb449d..5e760fc 100644 --- a/src/trx-server/src/listener.rs +++ b/src/trx-server/src/listener.rs @@ -9,18 +9,21 @@ use std::collections::HashSet; use std::net::SocketAddr; +use std::sync::Arc; use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; use tokio::net::{TcpListener, TcpStream}; use tokio::sync::{mpsc, oneshot, watch}; use tracing::{error, info}; -use trx_core::client::ClientEnvelope; -use trx_core::radio::freq::Freq; use trx_core::rig::command::RigCommand; use trx_core::rig::request::RigRequest; -use trx_core::rig::state::{RigMode, RigState}; -use trx_core::{ClientCommand, ClientResponse}; +use trx_core::rig::state::RigState; +use trx_core::ClientResponse; + +use trx_protocol::codec::parse_envelope; +use trx_protocol::auth::{SimpleTokenValidator, TokenValidator}; +use trx_protocol::mapping; /// Run the JSON TCP listener, accepting client connections. pub async fn run_listener( @@ -32,15 +35,17 @@ pub async fn run_listener( let listener = TcpListener::bind(addr).await?; info!("Listening on {}", addr); + let validator = Arc::new(SimpleTokenValidator::new(auth_tokens)); + loop { let (socket, peer) = listener.accept().await?; info!("Client connected: {}", peer); let tx = rig_tx.clone(); - let tokens = auth_tokens.clone(); let srx = state_rx.clone(); + let validator = Arc::clone(&validator); tokio::spawn(async move { - if let Err(e) = handle_client(socket, peer, tx, &tokens, srx).await { + if let Err(e) = handle_client(socket, peer, tx, validator, srx).await { error!("Client {} error: {:?}", peer, e); } }); @@ -51,7 +56,7 @@ async fn handle_client( socket: TcpStream, addr: SocketAddr, tx: mpsc::Sender, - auth_tokens: &HashSet, + validator: Arc, state_rx: watch::Receiver, ) -> std::io::Result<()> { let (reader, mut writer) = socket.into_split(); @@ -87,7 +92,7 @@ async fn handle_client( } }; - if let Err(err) = authorize(&envelope.token, auth_tokens) { + if let Err(err) = validator.as_ref().validate(&envelope.token) { let resp = ClientResponse { success: false, state: None, @@ -99,7 +104,7 @@ async fn handle_client( continue; } - let rig_cmd = map_command(envelope.cmd); + let rig_cmd = mapping::client_command_to_rig(envelope.cmd); // Fast path: serve GetSnapshot directly from the watch channel // so clients get a response even while the rig task is initializing. @@ -175,78 +180,3 @@ async fn handle_client( Ok(()) } -fn map_command(cmd: ClientCommand) -> RigCommand { - match cmd { - ClientCommand::GetState => RigCommand::GetSnapshot, - ClientCommand::SetFreq { freq_hz } => RigCommand::SetFreq(Freq { hz: freq_hz }), - ClientCommand::SetMode { mode } => RigCommand::SetMode(parse_mode(&mode)), - ClientCommand::SetPtt { ptt } => RigCommand::SetPtt(ptt), - ClientCommand::PowerOn => RigCommand::PowerOn, - ClientCommand::PowerOff => RigCommand::PowerOff, - ClientCommand::ToggleVfo => RigCommand::ToggleVfo, - ClientCommand::Lock => RigCommand::Lock, - ClientCommand::Unlock => RigCommand::Unlock, - ClientCommand::GetTxLimit => RigCommand::GetTxLimit, - ClientCommand::SetTxLimit { limit } => RigCommand::SetTxLimit(limit), - ClientCommand::SetAprsDecodeEnabled { enabled } => RigCommand::SetAprsDecodeEnabled(enabled), - ClientCommand::SetCwDecodeEnabled { enabled } => RigCommand::SetCwDecodeEnabled(enabled), - ClientCommand::SetCwAuto { enabled } => RigCommand::SetCwAuto(enabled), - ClientCommand::SetCwWpm { wpm } => RigCommand::SetCwWpm(wpm), - ClientCommand::SetCwToneHz { tone_hz } => RigCommand::SetCwToneHz(tone_hz), - ClientCommand::SetFt8DecodeEnabled { enabled } => RigCommand::SetFt8DecodeEnabled(enabled), - ClientCommand::ResetAprsDecoder => RigCommand::ResetAprsDecoder, - ClientCommand::ResetCwDecoder => RigCommand::ResetCwDecoder, - ClientCommand::ResetFt8Decoder => RigCommand::ResetFt8Decoder, - } -} - -fn parse_mode(s: &str) -> RigMode { - match s.to_uppercase().as_str() { - "LSB" => RigMode::LSB, - "USB" => RigMode::USB, - "CW" => RigMode::CW, - "CWR" => RigMode::CWR, - "AM" => RigMode::AM, - "FM" => RigMode::FM, - "DIG" | "DIGI" => RigMode::DIG, - "PKT" | "PACKET" => RigMode::PKT, - other => RigMode::Other(other.to_string()), - } -} - -fn parse_envelope(input: &str) -> Result { - match serde_json::from_str::(input) { - Ok(envelope) => Ok(envelope), - Err(_) => { - let cmd = serde_json::from_str::(input)?; - Ok(ClientEnvelope { token: None, cmd }) - } - } -} - -fn authorize(token: &Option, valid_tokens: &HashSet) -> Result<(), String> { - if valid_tokens.is_empty() { - return Ok(()); - } - - let Some(token) = token.as_ref() else { - return Err("missing authorization token".into()); - }; - - let candidate = strip_bearer(token); - if valid_tokens.contains(candidate) { - return Ok(()); - } - - Err("invalid authorization token".into()) -} - -fn strip_bearer(value: &str) -> &str { - let trimmed = value.trim(); - let prefix = "bearer "; - if trimmed.len() >= prefix.len() && trimmed[..prefix.len()].eq_ignore_ascii_case(prefix) { - &trimmed[prefix.len()..] - } else { - trimmed - } -}