Commit d2d9129d authored by Benjamin Lee's avatar Benjamin Lee 💬
Browse files

A lot of refactoring to make network error handling more correct.

parent 16b16945
......@@ -18,7 +18,6 @@ log = { version = "0.4.6", features = ["max_level_trace", "release_max_level_inf
renderdoc = { version = "0.4.0", optional = true }
palette = { version = "0.4.1", features = ["serde"] }
arrayvec = { version = "0.4.8", features = ["use_union"] }
take_mut = "0.2.2"
parking_lot = "0.7.0"
bincode = "1.0.1"
serde = "1.0.81"
......@@ -38,6 +37,7 @@ nalgebra = { version = "0.16.12", features = ["serde-serialize"] }
either = "1.5.0"
crossbeam = "0.6.0"
enum-kinds = "0.4.1"
take_mut = "0.2.2"
[build-dependencies]
shaderc = "0.3.12"
......
......@@ -9,6 +9,7 @@ use crate::networking::server::ServerPacket;
use crate::networking::tick::Interval;
use crate::networking::{
Error,
RecvError,
RttEstimator,
CONNECTION_TIMEOUT,
MAX_PACKET_SIZE,
......@@ -24,8 +25,6 @@ use serde::{Deserialize, Serialize};
use std::collections::VecDeque;
use std::io::{self, Cursor};
use std::net::SocketAddr;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::thread;
use std::time::{Duration, Instant};
......@@ -57,10 +56,11 @@ pub enum ClientPacket {
pub enum ClientState {
Connecting {
done: Sender<Result<Game, Error>>,
done: Sender<Result<(Game, ConnectedHandle), Option<Error>>>,
cursor: Point2<f32>,
},
Connected {
done: Sender<Option<Error>>,
tick: Interval,
rtt: RttEstimator,
ping: Interval,
......@@ -78,7 +78,6 @@ pub struct Client {
connection: Connection,
state: ClientState,
_shutdown: Registration,
disconnected: Arc<AtomicBool>,
/// Marks after `shutdown` has been received, to shutdown when the
/// `send_queue` is empty.
needs_shutdown: bool,
......@@ -86,14 +85,14 @@ pub struct Client {
stats: NetworkStats,
}
pub type ConnectingHandle =
Receiver<Result<(Game, ConnectedHandle), Option<Error>>>;
pub type ConnectedHandle = Receiver<Option<Error>>;
/// Client handle used while connecting to a sever.
pub struct ClientHandle {
shutdown: SetReadiness,
disconnected: Arc<AtomicBool>,
}
pub struct ConnectingHandle {
done: Receiver<Result<Game, Error>>,
}
pub fn connect(
......@@ -103,26 +102,17 @@ pub fn connect(
) -> Result<(ClientHandle, ConnectingHandle), Error> {
let (done_tx, done_rx) = channel::bounded(1);
let (shutdown_registration, shutdown_set_readiness) = Registration::new2();
let disconnected = Arc::new(AtomicBool::new(false));
let mut client = Client::new(
addr,
done_tx,
stats,
shutdown_registration,
disconnected.clone(),
cursor,
)?;
let client =
Client::new(addr, done_tx, stats, shutdown_registration, cursor)?;
thread::spawn(move || {
run_event_loop(&mut client);
run_event_loop(client);
info!("client done");
});
Ok((
ClientHandle {
shutdown: shutdown_set_readiness,
disconnected,
},
ConnectingHandle {
done: done_rx,
},
done_rx,
))
}
......@@ -130,28 +120,13 @@ impl ClientHandle {
/// Signals the client thread to shutdown.
pub fn shutdown(&self) {
if let Err(err) = self.shutdown.set_readiness(Ready::readable()) {
warn!("failed to signal shutdown to client: {}", err)
error!("failed to signal shutdown to client: {}", err)
}
}
/// Checks if the client has disconnected.
pub fn disconnected(&self) -> bool {
self.disconnected.load(Ordering::SeqCst)
}
}
impl ConnectingHandle {
/// Gets the connection result, if connection finished.
pub fn done(&mut self) -> Option<Result<Game, Error>> {
match self.done.try_recv() {
Ok(done) => Some(done),
Err(_) => None,
}
}
}
// Gracefully shutdown when the handle is lost.
impl Drop for ClientHandle {
/// Gracefully shutdown when the handle is lost.
fn drop(&mut self) {
self.shutdown();
}
......@@ -166,10 +141,19 @@ impl EventHandler for Client {
match event.token() {
SOCKET => {
if event.readiness().is_readable() {
self.socket_readable();
// Don't process any new messages while shutting down.
if self.needs_shutdown {
return false;
}
if let Err(err) = self.socket_readable() {
return self.start_shutdown(Some(err));
}
}
if event.readiness().is_writable() {
self.socket_writable();
if let Err(err) = self.socket_writable() {
return self.start_shutdown(Some(err));
}
}
if self.send_queue.is_empty() && self.needs_shutdown {
......@@ -179,16 +163,21 @@ impl EventHandler for Client {
}
},
TIMER => {
// Don't respond to timer events while shutting down.
if self.needs_shutdown {
return false;
}
while let Some(timeout) = self.timer.poll() {
match timeout {
TimeoutState::Ping => {
if let Err(err) = self.send_ping() {
error!("error sending ping packet: {}", err);
return self.start_shutdown(Some(err));
}
},
TimeoutState::Tick => {
if let Err(err) = self.send_tick() {
error!("error sending tick packet: {}", err);
return self.start_shutdown(Some(err));
}
},
TimeoutState::UpdateStats => {
......@@ -209,28 +198,14 @@ impl EventHandler for Client {
);
},
TimeoutState::LostConnection => {
self.disconnected.store(true, Ordering::Relaxed);
info!("client connection timed out");
if let ClientState::Connecting {
ref mut done,
..
} = self.state
{
let _ = done.send(Err(Error::TimedOut));
}
return true;
return self.start_shutdown(Some(Error::TimedOut));
},
}
}
},
SHUTDOWN => {
info!("client started shutdown");
if let Err(err) = self.send(&ClientPacket::Disconnect) {
error!("failed to send disconnect packet: {}", err);
// If this errored, just shut down immediately.
return true;
}
self.needs_shutdown = true;
return self.start_shutdown(None);
},
Token(_) => unreachable!(),
}
......@@ -242,25 +217,34 @@ impl EventHandler for Client {
impl Client {
pub fn new(
addr: SocketAddr,
done: Sender<Result<Game, Error>>,
done: Sender<Result<(Game, ConnectedHandle), Option<Error>>>,
stats: Sender<NetworkStats>,
shutdown: Registration,
disconnected: Arc<AtomicBool>,
cursor: Point2<f32>,
) -> Result<Client, Error> {
let socket = UdpSocket::bind(&"0.0.0.0:0".parse().unwrap())
.map_err(Error::bind_socket)?;
socket.connect(addr).map_err(Error::connect_socket)?;
let socket =
UdpSocket::bind(&"0.0.0.0:0".parse().unwrap()).map_err(|err| {
Error::BindSocket {
addr,
err,
}
})?;
socket.connect(addr).map_err(|err| {
Error::ConnectSocket {
addr,
err,
}
})?;
let mut timer = timer::Builder::default()
.tick_duration(Duration::from_millis(10))
.build();
let poll = Poll::new().map_err(Error::poll_init)?;
let poll = Poll::new().map_err(Error::poll)?;
poll.register(&socket, SOCKET, Ready::readable(), PollOpt::edge())
.map_err(Error::poll_register)?;
.map_err(Error::poll)?;
poll.register(&timer, TIMER, Ready::readable(), PollOpt::edge())
.map_err(Error::poll_register)?;
.map_err(Error::poll)?;
poll.register(&shutdown, SHUTDOWN, Ready::readable(), PollOpt::edge())
.map_err(Error::poll_register)?;
.map_err(Error::poll)?;
let timeout =
timer.set_timeout(CONNECTION_TIMEOUT, TimeoutState::LostConnection);
......@@ -278,7 +262,6 @@ impl Client {
done,
cursor,
},
disconnected,
_shutdown: shutdown,
stats_tx: stats,
stats: NetworkStats::default(),
......@@ -293,7 +276,60 @@ impl Client {
Ok(client)
}
pub fn socket_readable(&mut self) {
/// Starts shutting down the networking thread, with a provided reason.
///
/// If any errors occur at this point, returns `true` to indicate
/// that the event loop should hard-shutdown immediately.
#[must_use]
fn start_shutdown(&mut self, reason: Option<Error>) -> bool {
// If already shutting down, don't redo this stuff.
if self.needs_shutdown {
return true;
}
match self.state {
ClientState::Connecting {
ref mut done,
..
} => {
let _ = done.send(Err(reason));
},
ClientState::Connected {
ref mut done,
..
} => {
let _ = done.send(reason);
},
}
// Get rid of any pending packets.
self.send_queue.clear();
// Send off a bunch of disconnected packets to the server, in
// the hopes that at least one gets through.
for _ in 0..8 {
if let Err(err) = self.send(&ClientPacket::Disconnect) {
error!(
"error ocurred while sending disconnect packets: {}",
err
);
return true;
}
}
false
}
fn reregister_socket(&mut self, writable: bool) -> Result<(), Error> {
let readiness = if writable {
Ready::readable() | Ready::writable()
} else {
Ready::readable()
};
self.poll
.reregister(&self.socket, SOCKET, readiness, PollOpt::edge())
.map_err(Error::poll)
}
fn socket_readable(&mut self) -> Result<(), Error> {
loop {
match self.socket.recv(&mut self.recv_buffer) {
Ok(bytes_read) => {
......@@ -305,22 +341,29 @@ impl Client {
);
// Handle packet.
self.stats.bytes_in += bytes_read as u32;
if let Err(err) = self.on_recv(bytes_read) {
error!("{}", err);
if let Err(err) = self.on_recv(bytes_read)? {
error!(
"receiving packet failed ({:?}): {}",
&self.recv_buffer[0..bytes_read],
err
);
}
},
Err(err) => {
if err.kind() != io::ErrorKind::WouldBlock {
error!("error receiving packet on client: {}", err);
return Err(Error::SocketRead(err));
} else {
break;
}
},
}
}
Ok(())
}
pub fn socket_writable(&mut self) {
fn socket_writable(&mut self) -> Result<(), Error> {
while let Some(packet) = self.send_queue.pop_front() {
match self.socket.send(&packet) {
Err(err) => {
......@@ -329,6 +372,7 @@ impl Client {
"error sending packet from client ({:?}): {}",
&packet, err
);
return Err(Error::SocketWrite(err));
} else {
break;
}
......@@ -353,19 +397,9 @@ impl Client {
if self.send_queue.is_empty() {
// No longer care about writable events if there are no
// more packets to send.
if let Err(err) = self
.poll
.reregister(
&self.socket,
SOCKET,
Ready::readable(),
PollOpt::edge(),
)
.map_err(Error::poll_register)
{
error!("{}", err);
}
self.reregister_socket(false)?;
}
Ok(())
}
fn send_ping(&mut self) -> Result<(), Error> {
......@@ -408,14 +442,20 @@ impl Client {
Ok(())
}
fn on_recv(&mut self, bytes_read: usize) -> Result<(), Error> {
fn on_recv(
&mut self,
bytes_read: usize,
) -> Result<Result<(), RecvError>, Error> {
// Make sure that it fits in recv_buffer
if bytes_read > MAX_PACKET_SIZE {
return Err(Error::PacketTooLarge(bytes_read));
return Ok(Err(RecvError::PacketTooLarge(bytes_read)));
}
let packet = &self.recv_buffer[0..bytes_read];
let (packet, sequence, _, lost) =
self.connection.decode(Cursor::new(packet))?;
match self.connection.decode(Cursor::new(packet)) {
Ok(result) => result,
Err(err) => return Ok(Err(err)),
};
self.stats.packets_lost += lost.len() as u16;
let transition = match self.state {
......@@ -440,11 +480,13 @@ impl Client {
.set_timeout(ping.interval(), TimeoutState::Ping);
// Signal the main thread that connection finished.
done.send(Ok(game)).unwrap();
let (done_tx, done_rx) = channel::bounded(1);
let _ = done.send(Ok((game, done_rx)));
info!("completed connection to server");
// Transition to connected state.
Some(ClientState::Connected {
done: done_tx,
game: game_handle,
tick,
ping,
......@@ -484,30 +526,22 @@ impl Client {
self.state = transition;
}
Ok(())
Ok(Ok(()))
}
fn send(&mut self, contents: &ClientPacket) -> Result<u32, Error> {
// Don't send any additional packets while shutting down.
if self.needs_shutdown {
return Err(Error::ShuttingDown);
panic!("attempted to send packet while already shutting down");
}
let size = bincode::serialized_size(contents)
.map_err(Error::serialize)? as usize;
// Serialization errors are always fatal.
let size = bincode::serialized_size(contents).unwrap() as usize;
let mut packet = Vec::with_capacity(size + HEADER_BYTES);
let sequence = self.connection.send_header(&mut packet)?;
bincode::serialize_into(&mut packet, contents)
.map_err(Error::serialize)?;
let sequence = self.connection.send_header(&mut packet);
bincode::serialize_into(&mut packet, contents).unwrap();
self.send_queue.push_back(packet);
self.poll
.reregister(
&self.socket,
SOCKET,
Ready::readable() | Ready::writable(),
PollOpt::edge(),
)
.map_err(Error::poll_register)?;
self.reregister_socket(true)?;
Ok(sequence)
}
......
use crate::networking::Error;
use crate::networking::RecvError;
use byteorder::{ReadBytesExt, WriteBytesExt, BE};
use serde::de::DeserializeOwned;
use smallvec::SmallVec;
......@@ -29,13 +29,12 @@ impl Acks {
if sequence > self.ack {
// Packet newer than most recent packet, so shift
// everything.
self.ack_bits <<= sequence - self.ack;
self.ack_bits = self.ack_bits.wrapping_shl(sequence - self.ack);
self.ack_bits |= 1;
self.ack = sequence;
} else if self.ack - sequence <= 32 {
// Received a packet newer than this one before, but it's
// still in the 32-packet window, so store it.
self.ack_bits |= 1 << (self.ack - sequence);
} else {
// Received a packet newer than this one before.
self.ack_bits |= 1u32.wrapping_shl(self.ack - sequence);
}
}
......@@ -46,7 +45,7 @@ impl Acks {
if new.ack > self.ack {
// Anything that is outside the range of the new ack can
// be considered lost.
let mask = !(!0 >> (new.ack - self.ack));
let mask = !((!0u32).wrapping_shr(new.ack - self.ack));
lost.extend(
Acks {
ack_bits: !self.ack_bits & mask,
......@@ -56,12 +55,12 @@ impl Acks {
);
// Shift everything.
self.ack_bits <<= new.ack - self.ack;
self.ack_bits = self.ack_bits.wrapping_shl(new.ack - self.ack);
self.ack_bits |= new.ack;
self.ack = new.ack;
} else if self.ack - new.ack <= 32 {
self.ack_bits |= new.ack << (self.ack - new.ack);
};
} else {
self.ack_bits |= new.ack.wrapping_shl(self.ack - new.ack);
}
lost
}
......@@ -71,7 +70,7 @@ impl Acks {
if self.ack < sequence {
return false;
}
self.ack_bits & (1 << (self.ack - sequence)) != 0
self.ack_bits & (1u32.wrapping_shl(self.ack - sequence)) != 0
}
/// Returns an iterator over the acked packets.
......@@ -93,10 +92,12 @@ impl Connection {
pub fn recv_header<B: Read>(
&mut self,
mut packet: B,
) -> Result<(u32, Acks, SmallVec<[u32; 4]>), Error> {
let sequence = packet.read_u32::<BE>().map_err(Error::header_read)?;
let ack = packet.read_u32::<BE>().map_err(Error::header_read)?;
let ack_bits = packet.read_u32::<BE>().map_err(Error::header_read)?;
) -> Result<(u32, Acks, SmallVec<[u32; 4]>), RecvError> {
let sequence =
packet.read_u32::<BE>().map_err(RecvError::header_read)?;
let ack = packet.read_u32::<BE>().map_err(RecvError::header_read)?;
let ack_bits =
packet.read_u32::<BE>().map_err(RecvError::header_read)?;
self.acks.ack(sequence);
let acks = Acks {
......@@ -107,18 +108,17 @@ impl Connection {
Ok((sequence, acks, lost))
}
pub fn send_header<B: Write>(
&mut self,
mut packet: B,
) -> Result<u32, Error> {
/// Writes the header for the next packet into a buffer.
///
/// Panics if the buffer is not large enough or if an IO error
/// occurs while writing.
pub fn send_header<B: Write>(&mut self, mut packet: B) -> u32 {
let sequence = self.local_sequence;
self.local_sequence += 1;
packet.write_u32::<BE>(sequence).map_err(Error::header_write)?;
packet.write_u32::<BE>(self.acks.ack).map_err(Error::header_write)?;
packet
.write_u32::<BE>(self.acks.ack_bits)
.map_err(Error::header_write)?;
Ok(sequence)
packet.write_u32::<BE>(sequence).unwrap();
packet.write_u32::<BE>(self.acks.ack).unwrap();
packet.write_u32::<BE>(self.acks.ack_bits).unwrap();
sequence
}
/// Reads the header of a packet, and then deserializes the
......@@ -127,10 +127,10 @@ impl Connection {
pub fn decode<B: Read, P: DeserializeOwned>(
&mut self,
mut read: B,
) -> Result<(P, u32, Acks, SmallVec<[u32; 4]>), Error> {
) -> Result<(P, u32, Acks, SmallVec<[u32; 4]>), RecvError> {
let (sequence, acks, lost) = self.recv_header(&mut read)?;
let packet =
bincode::deserialize_from(read).map_err(Error::deserialize)?;
bincode::deserialize_from(read).map_err(RecvError::deserialize)?;
Ok((packet, sequence, acks, lost))
}
}
......@@ -9,7 +9,7 @@ pub trait EventHandler {
fn handle(&mut self, event: Event) -> bool;
}
pub fn run_event_loop<T: EventHandler>(handler: &mut T) {
pub fn run_event_loop<T: EventHandler>(mut handler: T) {
let mut events = Events::with_capacity(1024);
'event_loop: loop {
if let Err(err) = handler.poll().poll(&mut events, None) {
......
use bincode;
use failure::{Backtrace, Fail};
use std::io;
use std::net::SocketAddr;
use std::time::{Duration, Instant};
pub mod client;
......@@ -59,66 +61,57 @@ impl RttEstimator {
}
}
/// Non-fatal errors that occur on receiving a packet.
///
/// These should be logged, but generally do not end the connection.
#[derive(Fail, Debug)]
pub enum RecvError {
#[fail(display = "received packet that was too large ({} bytes)", _0)]
PacketTooLarge(usize),
#[fail(display = "reading packet header failed: {} {}", _0, _1)]
HeaderRead(io::Error, Backtrace),
#[fail(display = "deserializing packet payload failed: {} {}", _0, _1)]
Deserialize(bincode::Error, Backtrace),
}
/// (Mostly) fatal errors that should kill either the networking event
/// loop, or the particular connection in question.
#[derive(Fail, Debug)]
pub enum Error {
#[fail(display = "connection timed out")]