From 14e84e6f1c5709e35f8030d4a860ef9a4319600d Mon Sep 17 00:00:00 2001 From: Jeremy Wall Date: Sat, 20 Feb 2021 18:22:28 -0500 Subject: [PATCH] Handle multiple ping targets --- src/icmp.rs | 656 ++++++++++++++++++++++++++++++---------------------- src/main.rs | 10 +- 2 files changed, 385 insertions(+), 281 deletions(-) diff --git a/src/icmp.rs b/src/icmp.rs index 58498ec..a294672 100644 --- a/src/icmp.rs +++ b/src/icmp.rs @@ -11,8 +11,11 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -use std::time::{Duration, Instant}; -use std::{convert::TryFrom, ops::Sub}; +use std::ops::Sub; +use std::{ + collections::HashMap, + time::{Duration, Instant}, +}; use std::{ net::{IpAddr, Ipv4Addr, Ipv6Addr}, sync::{Arc, RwLock}, @@ -25,9 +28,8 @@ use icmp_socket::{ packet::{Icmpv4Message, Icmpv6Message, WithEchoRequest}, IcmpSocket, IcmpSocket4, IcmpSocket6, Icmpv4Packet, Icmpv6Packet, }; -use log::{error, info}; +use log::{debug, error, info}; use prometheus::{CounterVec, GaugeVec}; -use socket2::{self, SockAddr}; gflags::define! { /// The payload to use for the ping requests. @@ -39,6 +41,11 @@ gflags::define! { --pingTimeout: u64 = 2048 } +gflags::define! { + /// The delay between ping requests. + --pingDelay: u64 = 5 +} + gflags::define! { /// The size in bytes of the ping requests. --maxHops: u8 = 50 @@ -55,310 +62,403 @@ fn resolve_host_address(host: &str) -> String { ) } -fn loop_impl( - mut socket: Sock, - dest: Sock::AddrType, - packet_handler: PH, - err_handler: EH, +struct State { + sequence: u16, + destinations: HashMap, // domain, address + time_tracker: HashMap, + latency_guage: GaugeVec, + ping_counter: CounterVec, stop_signal: Arc>, -) where - PH: Fn(Sock::PacketType, socket2::SockAddr, Instant, u16) -> Option<()>, - EH: Fn(std::io::Error, bool) -> (), +} + +struct PingerImpl { + sock: Sock, + timeout: Duration, +} + +trait PacketHandler +where + AddrType: std::fmt::Display + Copy, + PacketType: WithEchoRequest, +{ + fn get_mut_state(&mut self) -> &mut State; + fn handle_pkt(&mut self, pkt: PacketType) -> bool; +} + +impl<'a> PacketHandler for &'a mut State { + fn get_mut_state(&mut self) -> &mut State { + return self; + } + + fn handle_pkt(&mut self, pkt: Icmpv6Packet) -> bool { + match pkt.message { + Icmpv6Message::Unreachable { + _unused, + invoking_packet, + } => { + match Icmpv6Packet::parse(&invoking_packet) { + Ok(Icmpv6Packet { + typ: _, + code: _, + checksum: _, + message: + Icmpv6Message::EchoRequest { + identifier, + sequence: _, + payload: _, + }, + }) => { + if let Some((domain_name, _addr)) = self.destinations.get(&identifier) { + self.ping_counter + .with(&prometheus::labels! {"result" => "unreachable", "domain" => domain_name}) + .inc(); + return true; + } + } + Err(e) => { + // We ignore these as well but log it. + error!("ICMP: Error parsing Unreachable invoking packet {:?}", e); + } + _ => { + // We ignore these + } + }; + } + Icmpv6Message::ParameterProblem { + pointer: _, + invoking_packet, + } => { + match Icmpv6Packet::parse(&invoking_packet) { + Ok(Icmpv6Packet { + typ: _, + code: _, + checksum: _, + message: + Icmpv6Message::EchoRequest { + identifier, + sequence: _, + payload: _, + }, + }) => { + if let Some((domain_name, _addr)) = self.destinations.get(&identifier) { + self.ping_counter + .with(&prometheus::labels! {"result" => "parameter_problem", "domain" => domain_name}) + .inc(); + return true; + } + } + Err(e) => { + // We ignore these as well but log it. + error!("ICMP: Error parsing Unreachable invoking packet {:?}", e); + } + _ => { + // We ignore these + } + } + } + Icmpv6Message::EchoReply { + identifier, + sequence, + payload: _, + } => { + if let Some((domain_name, dest)) = self.destinations.get(&identifier) { + if self.sequence != sequence { + error!("ICMP: Discarding sequence {}", sequence); + return false; + } + let elapsed = if let Some(send_time) = self.time_tracker.get(&identifier) { + Instant::now().sub(send_time.clone()).as_micros() as f64 / 1000.00 + } else { + return false; + }; + info!( + "ICMP: Reply from {}({}): time={}ms, seq={}", + domain_name, dest, elapsed, sequence, + ); + self.ping_counter + .with(&prometheus::labels! {"result" => "ok", "domain" => domain_name}) + .inc(); + if elapsed as i32 != 0 { + self.latency_guage + .with(&prometheus::labels! {"domain" => domain_name.as_str()}) + .set(elapsed); + } + return true; + } else { + info!("ICMP: Discarding wrong identifier {}", identifier); + } + } + _ => { + // We ignore the rest. + } + } + return false; + } +} + +impl<'a> PacketHandler for &'a mut State { + fn get_mut_state(&mut self) -> &mut State { + return self; + } + + fn handle_pkt(&mut self, pkt: Icmpv4Packet) -> bool { + match pkt.message { + Icmpv4Message::ParameterProblem { + pointer: _, + padding: _, + header: _, + } => { + self.ping_counter + .with(&prometheus::labels! {"result" => "parameter_problem", "domain" => "unknown"}) + .inc(); + } + Icmpv4Message::Unreachable { padding: _, header } => { + let dest_addr = Ipv4Addr::new(header[16], header[17], header[18], header[19]); + info!("ICMP: Destination Unreachable response from {}", dest_addr,); + self.ping_counter + .with(&prometheus::labels! {"result" => "unreachable", "domain" => "unknown"}) + .inc(); + } + Icmpv4Message::TimeExceeded { padding: _, header } => { + let dest_addr = Ipv4Addr::new(header[16], header[17], header[18], header[19]); + info!("ICMP: Timeout for {}", dest_addr); + self.ping_counter + .with(&prometheus::labels! {"result" => "timeout", "domain" => "unknown"}) + .inc(); + } + Icmpv4Message::EchoReply { + identifier, + sequence, + payload: _, + } => { + if let Some((domain_name, dest)) = self.destinations.get(&identifier) { + let elapsed = if let Some(send_time) = self.time_tracker.get(&identifier) { + Instant::now().sub(send_time.clone()).as_micros() as f64 / 1000.00 + } else { + return false; + }; + if self.sequence != sequence { + error!( + "ICMP: Discarding sequence {}, expected sequence {}", + sequence, self.sequence + ); + return false; + } + info!( + "ICMP: Reply from {}({}): time={}ms, seq={}", + domain_name, dest, elapsed, sequence, + ); + self.ping_counter + .with(&prometheus::labels! {"result" => "ok", "domain" => domain_name}) + .inc(); + self.latency_guage + .with(&prometheus::labels! {"domain" => domain_name.as_str()}) + .set(elapsed); + return true; + } else { + info!("ICMP: Discarding wrong identifier {}", identifier); + } + } + p => { + // We ignore the rest. + info!("ICMP Unhandled packet {:?}", p); + } + } + return false; + } +} + +trait Pinger +where + AddrType: std::fmt::Display + Copy, + PacketType: WithEchoRequest, +{ + fn send_all(&mut self, state: &mut State) -> std::io::Result<()>; + fn send_to_destination( + &mut self, + dest: AddrType, + identifier: u16, + sequence: u16, + ) -> std::io::Result; + + fn recv_pkt(&mut self) -> std::io::Result; + fn recv_all>(&mut self, handler: H); +} + +impl Pinger for PingerImpl +where Sock: IcmpSocket, Sock::AddrType: std::fmt::Display + Copy, Sock::PacketType: WithEchoRequest, { - if let Err(e) = socket.set_timeout(Duration::from_secs(1)) { - error!( - "ICMP: Failed to set timeout on socket. Not starting thread. {:?}", - e - ); - return; - } - let mut sequence: u16 = 0; - loop { - { - // Limit the scope of this lock - if *stop_signal.read().unwrap() { - info!("Stopping ping thread for {}", dest); - return; + fn send_all(&mut self, state: &mut State) -> std::io::Result<()> { + self.sock.set_timeout(self.timeout)?; + let destinations = state.destinations.clone(); + for (identifier, (domain_name, dest)) in destinations.into_iter() { + debug!("ICMP: sending echo request to {}({})", domain_name, dest); + match self.send_to_destination(dest, identifier, state.sequence) { + Err(e) => { + state + .ping_counter + .with(&prometheus::labels! {"result" => "err", "type" => "send"}) + .inc(); + error!( + "ICMP: error sending to domain: {} and address: {} failed: {:?}, Trying again later", + domain_name, &dest, e + ); + } + Ok(send_time) => { + state.time_tracker.insert(identifier, send_time); + } + } + { + // Scope the lock really tightly + if *state.stop_signal.read().unwrap() { + return Ok(()); + } } } + Ok(()) + } + + fn send_to_destination( + &mut self, + dest: Sock::AddrType, + identifier: u16, + sequence: u16, + ) -> std::io::Result { let packet = Sock::PacketType::with_echo_request( - 42, + identifier, sequence, PINGPAYLOAD.flag.as_bytes().to_owned(), ) .unwrap(); let send_time = Instant::now(); - if let Err(e) = socket.send_to(dest, packet) { - err_handler(e, true); - } else { + self.sock.send_to(dest, packet)?; + Ok(send_time) + } + + fn recv_pkt(&mut self) -> std::io::Result { + let (response, _addr) = self.sock.rcv_from()?; + Ok(response) + } + + fn recv_all>(&mut self, mut handler: H) { + let expected_len = handler.get_mut_state().time_tracker.len(); + for _ in 0..expected_len { loop { - // Keep going until we get the packet we are looking for. - match socket.rcv_from() { - Err(e) => { - err_handler(e, false); - } - Ok((resp, sock_addr)) => { - if packet_handler(resp, sock_addr, send_time, sequence).is_some() { - sequence = sequence.wrapping_add(1); + // Receive loop + match self.recv_pkt() { + Ok(pkt) => { + if handler.handle_pkt(pkt) { + // break out of the recv loop break; } } + Err(e) => { + error!("Error receiving packet: {:?}", e); + handler + .get_mut_state() + .ping_counter + .with(&prometheus::labels! {"result" => "err", "domain" => "unknown"}) + .inc(); + } } - // Give up after 3 seconds and send another packet. - if Instant::now() - send_time > Duration::from_secs(3) { - break; + { + // Scope the lock really tightly. + if *handler.get_mut_state().stop_signal.read().unwrap() { + return; + } } } } - std::thread::sleep(Duration::from_secs(3)); + let mut state = handler.get_mut_state(); + state.sequence = state.sequence.wrapping_add(1); } } pub fn start_echo_loop( - domain_name: &str, + domain_names: &Vec<&str>, stop_signal: Arc>, ping_latency_guage: GaugeVec, ping_counter: CounterVec, ) { - let resolved = resolve_host_address(domain_name); - info!( - "Attempting to ping domain {} at address: {}", - domain_name, resolved - ); - let dest = resolved - .parse::() - .expect(&format!("Invalid IP Address {}", resolved)); + let resolved: Vec<(String, IpAddr)> = domain_names + .iter() + .map(|domain_name| { + let resolved = resolve_host_address(domain_name); + let dest = resolved + .parse::() + .expect(&format!("Invalid IP Address {}", resolved)); + (domain_name.to_string(), dest) + }) + .collect(); + let mut v4_targets: Vec<(String, Ipv4Addr)> = Vec::new(); + let mut v6_targets: Vec<(String, Ipv6Addr)> = Vec::new(); + for (name, addr) in resolved { + match addr { + IpAddr::V6(addr) => { + v6_targets.push((name, addr)); + } + IpAddr::V4(addr) => { + v4_targets.push((name, addr)); + } + } + } - let err_handler = |e: std::io::Error, send: bool| { - ping_counter - .with(&prometheus::labels! {"result" => "err", "domain" => domain_name}) - .inc(); - if send { - error!( - "ICMP: error sending to domain: {} and address: {} failed: {:?}, Trying again later", - domain_name, &dest, e - ); - } else { - error!( - "ICMP: error receiving for domain: {} and address: {} failed: {:?}, Trying again later", - domain_name, &dest, e - ); - } + let mut v4_destinations = HashMap::new(); + let mut v4_id_counter = 42; + for target in v4_targets { + info!("ICMP: Attempting ping to {}({})", target.0, target.1); + v4_destinations.insert(v4_id_counter, target.clone()); + v4_id_counter += 1; + } + let mut v4_state = State { + sequence: 0, + destinations: v4_destinations, + time_tracker: HashMap::new(), + latency_guage: ping_latency_guage.clone(), + ping_counter: ping_counter.clone(), + stop_signal: stop_signal.clone(), }; - match dest { - IpAddr::V4(dest) => { - let mut socket = IcmpSocket4::try_from(Ipv4Addr::new(0, 0, 0, 0)).unwrap(); - socket.set_max_hops(MAXHOPS.flag as u32); - let packet_handler = |p: Icmpv4Packet, - _s: SockAddr, - send_time: Instant, - seq: u16| - -> Option<()> { - // We only want to handle replies for the address we are pinging. - match p.message { - Icmpv4Message::ParameterProblem { - pointer: _, - padding: _, - header, - } => { - let dest_addr = - Ipv4Addr::new(header[16], header[17], header[18], header[19]); - if dest_addr == dest { - ping_counter - .with(&prometheus::labels! {"result" => "parameter_problem", "domain" => domain_name}) - .inc(); - } else { - return None; - } - } - Icmpv4Message::Unreachable { padding: _, header } => { - let dest_addr = - Ipv4Addr::new(header[16], header[17], header[18], header[19]); - if dest_addr == dest { - info!( - "ICMP: Destination: {:?} Unreachable {} response from {}", - dest_addr, - dest, - _s.as_inet().unwrap().ip() - ); - ping_counter - .with(&prometheus::labels! {"result" => "unreachable", "domain" => domain_name}) - .inc(); - } else { - return None; - } - } - Icmpv4Message::TimeExceeded { padding: _, header } => { - let dest_addr = - Ipv4Addr::new(header[16], header[17], header[18], header[19]); - if dest_addr == dest { - info!("ICMP: Timeout for {}", dest); - ping_counter - .with(&prometheus::labels! {"result" => "timeout", "domain" => domain_name}) - .inc(); - } else { - return None; - } - } - Icmpv4Message::EchoReply { - identifier, - sequence, - payload: _, - } => { - if identifier != 42 { - info!("ICMP: Discarding wrong identifier {}", identifier); - return None; - } - if sequence != seq { - error!( - "ICMP: Discarding sequence {}, expected sequence {}", - sequence, seq - ); - return None; - } - let elapsed = - Instant::now().sub(send_time.clone()).as_micros() as f64 / 1000.00; - info!( - "ICMP: Reply from {}: time={}ms, seq={}", - dest, elapsed, sequence, - ); - ping_counter - .with(&prometheus::labels! {"result" => "ok", "domain" => domain_name}) - .inc(); - if elapsed as i32 != 0 { - ping_latency_guage - .with(&prometheus::labels! {"domain" => domain_name}) - .set(elapsed); - } - } - p => { - // We ignore the rest. - info!("ICMP Unhandled packet {:?}", p); - return None; - } - } - Some(()) - }; - loop_impl(socket, dest, packet_handler, err_handler, stop_signal); - } - IpAddr::V6(dest) => { - let mut socket = IcmpSocket6::try_from(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0)).unwrap(); - socket.set_max_hops(MAXHOPS.flag as u32); - let packet_handler = |p: Icmpv6Packet, - _s: SockAddr, - send_time: Instant, - seq: u16| - -> Option<()> { - match p.message { - Icmpv6Message::Unreachable { - _unused, - invoking_packet, - } => { - match Icmpv6Packet::parse(&invoking_packet) { - Ok(Icmpv6Packet { - typ: _, - code: _, - checksum: _, - message: - Icmpv6Message::EchoRequest { - identifier, - sequence: _, - payload: _, - }, - }) => { - if identifier == 42 { - ping_counter - .with(&prometheus::labels! {"result" => "unreachable", "domain" => domain_name}) - .inc(); - return Some(()); - } - } - Err(e) => { - // We ignore these as well but log it. - error!("ICMP: Error parsing Unreachable invoking packet {:?}", e); - } - _ => { - // We ignore these - } - }; - return None; - } - Icmpv6Message::ParameterProblem { - pointer: _, - invoking_packet, - } => { - match Icmpv6Packet::parse(&invoking_packet) { - Ok(Icmpv6Packet { - typ: _, - code: _, - checksum: _, - message: - Icmpv6Message::EchoRequest { - identifier, - sequence: _, - payload: _, - }, - }) => { - if identifier == 42 { - ping_counter - .with(&prometheus::labels! {"result" => "parameter_problem", "domain" => domain_name}) - .inc(); - return Some(()); - } - } - Err(e) => { - // We ignore these as well but log it. - error!("ICMP: Error parsing Unreachable invoking packet {:?}", e); - } - _ => { - // We ignore these - } - } - return None; - } - Icmpv6Message::EchoReply { - identifier, - sequence, - payload: _, - } => { - if identifier != 42 { - info!("ICMP: Discarding wrong identifier {}", identifier); - return None; - } - if sequence != seq { - error!("ICMP: Discarding sequence {}", sequence); - return None; - } - let elapsed = - Instant::now().sub(send_time.clone()).as_micros() as f64 / 1000.00; - info!( - "ICMP: Reply from {}: time={}ms, seq={}", - dest, elapsed, sequence, - ); - info!( - "ICMP: Reply from {}: time={}ms, seq={}", - dest, elapsed, sequence, - ); - ping_counter - .with(&prometheus::labels! {"result" => "ok", "domain" => domain_name}) - .inc(); - if elapsed as i32 != 0 { - ping_latency_guage - .with(&prometheus::labels! {"domain" => domain_name}) - .set(elapsed); - } - } - _ => { - // We ignore the rest. - return None; - } - } - Some(()) - }; - loop_impl(socket, dest, packet_handler, err_handler, stop_signal); - } + let mut v6_destinations = HashMap::new(); + let mut v6_id_counter = 42; + for target in v6_targets { + info!("ICMP: Attempting ping to {}({})", target.0, target.1); + v6_destinations.insert(v6_id_counter, target.clone()); + v6_id_counter += 1; + } + let mut v4_pinger = PingerImpl { + sock: IcmpSocket4::new().expect("Failed to open Icmpv4 Socket"), + timeout: Duration::from_secs(1), }; + let mut v6_state = State { + sequence: 0, + destinations: v6_destinations, + time_tracker: HashMap::new(), + latency_guage: ping_latency_guage, + ping_counter, + stop_signal: stop_signal.clone(), + }; + let mut v6_pinger = PingerImpl { + sock: IcmpSocket6::new().expect("Failed to open Icmpv6 Socket"), + timeout: Duration::from_secs(1), + }; + loop { + v4_pinger + .send_all(&mut v4_state) + .expect("Error sending packets on socket"); + v6_pinger + .send_all(&mut v6_state) + .expect("Error sending packets on socket"); + v4_pinger.recv_all(&mut v4_state); + v6_pinger.recv_all(&mut v6_state); + { + // Scope the lock really tightly + if *stop_signal.read().unwrap() { + return; + } + } + std::thread::sleep(Duration::from_secs(PINGDELAY.flag)) + } } diff --git a/src/main.rs b/src/main.rs index 935400b..4d60a39 100644 --- a/src/main.rs +++ b/src/main.rs @@ -180,13 +180,17 @@ fn main() -> anyhow::Result<()> { }); parent.adopt(Box::new(render_thread)); } - for domain_name in ping_hosts.iter().cloned() { - // TODO(Prometheus stats) + { let stop_signal = stop_signal.clone(); let ping_latency_vec = ping_latency_vec.clone(); let ping_counter_vec = ping_counter_vec.clone(); let ping_thread = thread::Pending::new(move || { - icmp::start_echo_loop(domain_name, stop_signal, ping_latency_vec, ping_counter_vec); + icmp::start_echo_loop( + &ping_hosts, + stop_signal.clone(), + ping_latency_vec, + ping_counter_vec, + ); }); parent.schedule(Box::new(ping_thread)); }