diff --git a/README.md b/README.md index 0f04028..c35666b 100644 --- a/README.md +++ b/README.md @@ -48,6 +48,9 @@ that looks like this: [gcm] api_key = "your-api-key" + [apns] + keyfile = "your-keyfile.p8" + Then simply run export RUST_LOG=push_relay=debug,hyper=info diff --git a/src/main.rs b/src/main.rs index 5f60de8..b5d1fe9 100644 --- a/src/main.rs +++ b/src/main.rs @@ -27,6 +27,8 @@ mod errors; mod push; mod server; +use std::fs::File; +use std::io::Read; use std::net::SocketAddr; use std::process; @@ -74,13 +76,34 @@ fn main() { error!("Invalid config file: No [gcm] section in {}", configfile); process::exit(2); }); - let api_key = config_gcm.get("api_key").unwrap_or_else(|| { + let gcm_api_key = config_gcm.get("api_key").unwrap_or_else(|| { error!("Invalid config file: No 'api_key' key in [gcm] section in {}", configfile); process::exit(2); }); + // Determine APNS API keyfile + let config_apns = config.section(Some("apns".to_owned())).unwrap_or_else(|| { + error!("Invalid config file: No [apns] section in {}", configfile); + process::exit(2); + }); + let apns_keyfile_path = config_apns.get("keyfile").unwrap_or_else(|| { + error!("Invalid config file: No 'keyfile' key in [apns] section in {}", configfile); + process::exit(2); + }); + + // Open APNS keyfile + let mut apns_keyfile = File::open(apns_keyfile_path).unwrap_or_else(|e| { + error!("Invalid 'keyfile' path: Could not open '{}': {}", apns_keyfile_path, e); + process::exit(3); + }); + let mut apns_api_key_string = String::new(); + apns_keyfile.read_to_string(&mut apns_api_key_string).unwrap_or_else(|e| { + error!("Invalid 'keyfile' path: Could not read '{}': {}", apns_keyfile_path, e); + process::exit(3); + }); + info!("Starting Push Relay Server {} on {}", VERSION, &addr); - server::serve(api_key, addr).unwrap_or_else(|e| { + server::serve(gcm_api_key, apns_api_key_string, addr).unwrap_or_else(|e| { error!("Could not start relay server: {}", e); process::exit(3); }); diff --git a/src/server.rs b/src/server.rs index 45a71f7..bc3a4c0 100644 --- a/src/server.rs +++ b/src/server.rs @@ -14,10 +14,11 @@ use ::utils::BoxedFuture; /// Start the server and run infinitely. -pub fn serve( - api_key: S, +pub fn serve( + gcm_api_key: S, + apns_api_key: T, listen_on: SocketAddr, -) -> Result<(), HyperError> where S: ToString { +) -> Result<(), HyperError> where S: ToString, T: ToString { // TODO: CSRF // Create reactor loop @@ -29,7 +30,8 @@ pub fn serve( // Create server let serve = Http::new().serve_addr_handle(&listen_on, &handle1, || { Ok(PushHandler { - api_key: api_key.to_string(), + gcm_api_key: gcm_api_key.to_string(), + apns_api_key: apns_api_key.to_string(), handle: handle2.clone(), }) })?; @@ -45,7 +47,8 @@ pub fn serve( } pub struct PushHandler { - api_key: String, + gcm_api_key: String, + apns_api_key: String, handle: Handle, } @@ -101,7 +104,7 @@ impl Service for PushHandler { }; // Parse request body - let api_key_clone = self.api_key.clone(); + let api_key_clone = self.gcm_api_key.clone(); let handle_clone = self.handle.clone(); Box::new( body @@ -219,11 +222,20 @@ mod tests { ::std::str::from_utf8(&bytes).unwrap().to_string() } + fn get_handler() -> (Core, PushHandler) { + let core = Core::new().unwrap(); + let handler = PushHandler { + gcm_api_key: "aassddff".into(), + apns_api_key: "aassddff".into(), + handle: core.handle(), + }; + (core, handler) + } + /// Handle invalid paths #[test] fn test_invalid_path() { - let mut core = Core::new().unwrap(); - let handler = PushHandler { api_key: "aassddff".into(), handle: core.handle() }; + let (mut core, handler) = get_handler(); let req = Request::new(Method::Post, Uri::from_str("/larifari").unwrap()); let resp = core.run(handler.call(req)).unwrap(); @@ -234,8 +246,7 @@ mod tests { /// Handle invalid methods #[test] fn test_invalid_method() { - let mut core = Core::new().unwrap(); - let handler = PushHandler { api_key: "aassddff".into(), handle: core.handle() }; + let (mut core, handler) = get_handler(); let req = Request::new(Method::Get, Uri::from_str("/push").unwrap()); let resp = core.run(handler.call(req)).unwrap(); @@ -246,8 +257,7 @@ mod tests { /// Handle invalid request content type #[test] fn test_invalid_contenttype() { - let mut core = Core::new().unwrap(); - let handler = PushHandler { api_key: "aassddff".into(), handle: core.handle() }; + let (mut core, handler) = get_handler(); let req = Request::new(Method::Post, Uri::from_str("/push").unwrap()); let resp = core.run(handler.call(req)).unwrap(); @@ -260,8 +270,7 @@ mod tests { /// A request without parameters should result in a HTTP 400 response. #[test] fn test_no_params() { - let mut core = Core::new().unwrap(); - let handler = PushHandler { api_key: "aassddff".into(), handle: core.handle() }; + let (mut core, handler) = get_handler(); let mut req = Request::new(Method::Post, Uri::from_str("/push").unwrap()); req.headers_mut().set(ContentType::form_url_encoded()); @@ -275,8 +284,7 @@ mod tests { /// A request wit missing parameters should result in a HTTP 400 response. #[test] fn test_missing_params() { - let mut core = Core::new().unwrap(); - let handler = PushHandler { api_key: "aassddff".into(), handle: core.handle() }; + let (mut core, handler) = get_handler(); let mut req = Request::new(Method::Post, Uri::from_str("/push").unwrap()); req.headers_mut().set(ContentType::form_url_encoded()); @@ -291,8 +299,7 @@ mod tests { /// A request wit missing parameters should result in a HTTP 400 response. #[test] fn test_bad_token_type() { - let mut core = Core::new().unwrap(); - let handler = PushHandler { api_key: "aassddff".into(), handle: core.handle() }; + let (mut core, handler) = get_handler(); let mut req = Request::new(Method::Post, Uri::from_str("/push").unwrap()); req.headers_mut().set(ContentType::form_url_encoded()); @@ -317,8 +324,7 @@ mod tests { }"#) .create(); - let mut core = Core::new().unwrap(); - let handler = PushHandler { api_key: "aassddff".into(), handle: core.handle() }; + let (mut core, handler) = get_handler(); let mut req = Request::new(Method::Post, Uri::from_str("/push").unwrap()); req.headers_mut().set(ContentType::form_url_encoded());