// Copyright 2020 Jeremy Wall // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; use std::ops::Sub; use std::{ collections::{BTreeMap, HashMap}, time::{Duration, Instant}, }; use crate::util; use gflags; use icmp_socket::{ packet::{Icmpv4Message, Icmpv6Message, WithEchoRequest}, IcmpSocket, IcmpSocket4, IcmpSocket6, Icmpv4Packet, Icmpv6Packet, }; use nursery::{thread, Nursery}; use prometheus::{CounterVec, GaugeVec}; use tracing::{debug, error, info, instrument, warn}; gflags::define! { /// The payload to use for the ping requests. --pingPayload = "durnitisp" } gflags::define! { /// The timeout for ping requests in seconds. --pingTimeout: u64 = 3 } gflags::define! { /// The delay between ping requests. --pingDelay: u64 = 5 } gflags::define! { /// The size in bytes of the ping requests. --maxHops: u8 = 50 } fn resolve_host_address(host: &str) -> String { format!( "{}", util::resolve_hosts(&vec![host]) .unwrap() .first() .unwrap() .unwrap() ) } struct State { destinations: HashMap, // domain, address // TODO(jwall): This should be a time tracker by both identifier and sequence time_tracker: BTreeMap>, destination_counter: BTreeMap, latency_guage: GaugeVec, ping_counter: CounterVec, } impl State { fn handle_echo_reply(&mut self, identifier: u16, sequence: u16) -> bool { if let Some((domain_name, dest)) = self.destinations.get(&identifier) { let time_tracker = self.time_tracker.get_mut(&identifier); let mut result = false; if let Some(Some(send_time)) = time_tracker.as_ref().map(|m| m.get(&sequence)) { let elapsed = Instant::now().sub(send_time.clone()).as_micros() as f64 / 1000.00; // We make a copy here to avoid the borrow above sticking around for too long. info!( domain=domain_name, %dest, time = elapsed, seq = sequence, "Reply", ); self.ping_counter .with(&prometheus::labels! {"result" => "ok", "domain" => domain_name}) .inc(); if elapsed as i32 != 0 { self.latency_guage .with(&prometheus::labels! {"domain" => domain_name.as_str()}) .set(elapsed); } self.time_tracker .get_mut(&identifier) .and_then(|m| m.remove(&sequence)); result = true; } else { error!(sequence, "Discarding unexpected sequence",); }; // Check all the other sequences to see if they have expired timeouts yet. // Record timeout for the expired sequences. // Remove the timeouts for the expired sequences. let expired_sequences = self.time_tracker.get(&identifier).map(|m| { let mut for_delete = Vec::with_capacity(m.len()); let m = m.clone(); { for (k, send_time) in m.iter() { if Instant::now().sub(*send_time) >= Duration::from_secs(PINGTIMEOUT.flag) { info!( domain=domain_name, %dest, seq = sequence, "Dropped" ); self.ping_counter .with(&prometheus::labels! {"result" => "dropped", "domain" => domain_name}) .inc(); for_delete.push(*k); } } } for_delete }); for k in expired_sequences.unwrap_or_default() { self.time_tracker .get_mut(&identifier) .and_then(|m| m.remove(&k)); } return result; } else { warn!(identifier, "Discarding wrong identifier"); } return false; } } struct PingerImpl { sock: Sock, } trait PacketHandler where AddrType: std::fmt::Display + Copy, PacketType: WithEchoRequest, { fn get_mut_state(&mut self) -> &mut State; fn handle_pkt(&mut self, pkt: PacketType) -> bool; } impl<'a> PacketHandler for &'a mut State { fn get_mut_state(&mut self) -> &mut State { return self; } #[instrument(level = "debug", skip(self))] fn handle_pkt(&mut self, pkt: Icmpv6Packet) -> bool { debug!("handling packet"); match pkt.message { Icmpv6Message::Unreachable { _unused, invoking_packet, } => { match Icmpv6Packet::parse(&invoking_packet) { Ok(Icmpv6Packet { typ: _, code: _, checksum: _, message: Icmpv6Message::EchoRequest { identifier, sequence: _, payload: _, }, }) => { if let Some((domain_name, _addr)) = self.destinations.get(&identifier) { self.ping_counter .with(&prometheus::labels! {"result" => "unreachable", "domain" => domain_name}) .inc(); return true; } } Err(e) => { // We ignore these as well but log it. error!(err = ?e, "Error parsing Unreachable"); } _ => { // We ignore these } }; } Icmpv6Message::ParameterProblem { pointer: _, invoking_packet, } => { match Icmpv6Packet::parse(&invoking_packet) { Ok(Icmpv6Packet { typ: _, code: _, checksum: _, message: Icmpv6Message::EchoRequest { identifier, sequence: _, payload: _, }, }) => { if let Some((domain_name, _addr)) = self.destinations.get(&identifier) { self.ping_counter .with(&prometheus::labels! {"result" => "parameter_problem", "domain" => domain_name}) .inc(); return true; } } Err(e) => { // We ignore these as well but log it. error!(err = ?e, "Error parsing ParameterProblem"); } _ => { // We ignore these } } } Icmpv6Message::EchoReply { identifier, sequence, payload: _, } => { return self.handle_echo_reply(identifier, sequence); } _ => { // We ignore the rest. } } return false; } } impl<'a> PacketHandler for &'a mut State { fn get_mut_state(&mut self) -> &mut State { return self; } #[instrument(level = "debug", skip(self))] fn handle_pkt(&mut self, pkt: Icmpv4Packet) -> bool { debug!("handling packet"); match pkt.message { Icmpv4Message::EchoReply { identifier, sequence, payload: _, } => { return self.handle_echo_reply(identifier, sequence); } _ => { // We ignore the rest. info!("Unhandled packet"); } } return false; } } trait Pinger where AddrType: std::fmt::Display + Copy, PacketType: WithEchoRequest, { fn send_all(&mut self, state: &mut State) -> std::io::Result<()>; fn send_to_destination( &mut self, dest: AddrType, identifier: u16, sequence: u16, ) -> std::io::Result; fn recv_pkt(&mut self) -> std::io::Result; fn recv_all>(&mut self, handler: H); fn send_pkt( &mut self, state: &mut State, identifier: u16, dest: AddrType, domain_name: &String, ) -> std::io::Result<()> { let sequence = *state.destination_counter.entry(identifier).or_insert(0); debug!( domain=domain_name, %dest, sequence, "Sending echo request", ); match self.send_to_destination(dest, identifier, sequence) { Err(e) => { state .ping_counter .with(&prometheus::labels! {"result" => "err", "type" => "send"}) .inc(); error!( domain=domain_name, %dest, err=?e, "Error sending. Trying again later", ); } Ok(send_time) => { state .time_tracker .entry(identifier) .or_insert_with(|| BTreeMap::new()) .insert(sequence, send_time); } } state .destination_counter .get_mut(&identifier) .map(|v| *v = v.wrapping_add(1)); Ok(()) } } impl Pinger for PingerImpl where Sock: IcmpSocket, Sock::AddrType: std::fmt::Display + Copy, Sock::PacketType: WithEchoRequest, { #[instrument(skip_all)] fn send_all(&mut self, state: &mut State) -> std::io::Result<()> { let destinations = state.destinations.clone(); debug!("Attempting to send packets for all domains"); for (identifier, (domain_name, dest)) in destinations.into_iter() { self.send_pkt(state, identifier, dest, &domain_name)?; } debug!("Finished sending for domains"); Ok(()) } fn send_to_destination( &mut self, dest: Sock::AddrType, identifier: u16, sequence: u16, ) -> std::io::Result { let packet = Sock::PacketType::with_echo_request( identifier, sequence, PINGPAYLOAD.flag.as_bytes().to_owned(), ) .unwrap(); let send_time = Instant::now(); self.sock.send_to(dest, packet)?; Ok(send_time) } fn recv_pkt(&mut self) -> std::io::Result { let (response, _addr) = self.sock.rcv_from()?; Ok(response) } #[instrument(skip(self, handler))] fn recv_all>(&mut self, mut handler: H) { if handler.get_mut_state().destinations.is_empty() { debug!("Nothing to send to so skipping for this socket"); return; }; if handler .get_mut_state() .time_tracker .values() .find(|item| !item.is_empty()) .is_none() { // nothing has been sent yet so no need to try to recv packets debug!("Nothing to recieve so skipping for this socket"); return; } self.sock.set_timeout(None); let loop_start_time = Instant::now(); loop { // Receive loop debug!("Attempting to recieve packets on socket"); match self.recv_pkt() { Ok(pkt) => { if handler.handle_pkt(pkt) { // break out of the recv loop debug!("Recieved Packet"); return; } } Err(e) => { error!(err = ?e, "Error receiving packet"); handler .get_mut_state() .ping_counter .with(&prometheus::labels! {"result" => "err", "domain" => "unknown"}) .inc(); return; } } if (Instant::now() - loop_start_time) > Duration::from_secs(PINGTIMEOUT.flag) { info!("Timing out on recieve loop"); return; } } } } struct Multi { v4_state: State, v6_state: State, v4_pinger: PingerImpl, v6_pinger: PingerImpl, } impl Multi { fn send_all(&mut self) { self.v4_pinger .send_all(&mut self.v4_state) .expect("Error sending packets on socket"); self.v6_pinger .send_all(&mut self.v6_state) .expect("Error sending packets on socket"); } fn recv_all(&mut self) { self.v4_pinger.recv_all(&mut self.v4_state); self.v6_pinger.recv_all(&mut self.v6_state); } } #[instrument(name = "ICMP", skip_all)] pub fn schedule_echo_server( domain_names: &Vec<&str>, ping_latency_guage: GaugeVec, ping_counter: CounterVec, parent: &mut Nursery, ) { let resolved: Vec<(String, IpAddr)> = domain_names .iter() .map(|domain_name| { let resolved = resolve_host_address(domain_name); let dest = resolved .parse::() .expect(&format!("Invalid IP Address {}", resolved)); (domain_name.to_string(), dest) }) .collect(); let mut v4_targets: Vec<(String, Ipv4Addr)> = Vec::new(); let mut v6_targets: Vec<(String, Ipv6Addr)> = Vec::new(); for (name, addr) in resolved { match addr { IpAddr::V6(addr) => { v6_targets.push((name, addr)); } IpAddr::V4(addr) => { v4_targets.push((name, addr)); } } } let mut v4_destinations = HashMap::new(); let mut v4_id_counter = 42; for target in v4_targets { info!( domain_name = target.0, address = %target.1, "Attempting ping" ); v4_destinations.insert(v4_id_counter, target.clone()); v4_id_counter += 1; } let v4_state = State { destinations: v4_destinations, time_tracker: BTreeMap::new(), destination_counter: BTreeMap::new(), latency_guage: ping_latency_guage.clone(), ping_counter: ping_counter.clone(), }; let mut v6_destinations = HashMap::new(); let mut v6_id_counter = 42; for target in v6_targets { info!( domain_name = target.0, address = %target.1, "Attempting ping" ); v6_destinations.insert(v6_id_counter, target.clone()); v6_id_counter += 1; } let v4_pinger = PingerImpl { sock: IcmpSocket4::new().expect("Failed to open Icmpv4 Socket"), }; let v6_state = State { destinations: v6_destinations, time_tracker: BTreeMap::new(), destination_counter: BTreeMap::new(), latency_guage: ping_latency_guage, ping_counter, }; let v6_pinger = PingerImpl { sock: IcmpSocket6::new().expect("Failed to open Icmpv6 Socket"), }; let multi = std::sync::Arc::new(std::sync::Mutex::new(Multi { v4_pinger, v6_pinger, v4_state, v6_state, })); let send_multi = multi.clone(); let send_thread = thread::Pending::new(move || { info!("Starting send thread"); loop { { send_multi.lock().unwrap().send_all(); } std::thread::sleep(Duration::from_secs(1)); } }); let recv_thread = thread::Pending::new(move || { info!("Starting recv thread"); loop { { multi.lock().unwrap().recv_all(); } std::thread::sleep(Duration::from_millis(5)); } }); parent.schedule(Box::new(send_thread)); parent.schedule(Box::new(recv_thread)); }