diff --git a/examples/serenity/voice/src/main.rs b/examples/serenity/voice/src/main.rs index 53f8e5338..8059b7f2c 100644 --- a/examples/serenity/voice/src/main.rs +++ b/examples/serenity/voice/src/main.rs @@ -278,15 +278,7 @@ async fn play(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult { }, }; - if !url.starts_with("http") { - check_msg( - msg.channel_id - .say(&ctx.http, "Must provide a valid URL") - .await, - ); - - return Ok(()); - } + let do_search = !url.starts_with("http"); let guild_id = msg.guild_id.unwrap(); @@ -305,8 +297,12 @@ async fn play(ctx: &Context, msg: &Message, mut args: Args) -> CommandResult { if let Some(handler_lock) = manager.get(guild_id) { let mut handler = handler_lock.lock().await; - let src = YoutubeDl::new(http_client, url); - let _ = handler.play_input(src.into()); + let mut src = if do_search { + YoutubeDl::new_search(http_client, url) + } else { + YoutubeDl::new(http_client, url) + }; + let _ = handler.play_input(src.clone().into()); check_msg(msg.channel_id.say(&ctx.http, "Playing song").await); } else { diff --git a/src/input/sources/ytdl.rs b/src/input/sources/ytdl.rs index d7827e7cc..35a3e7516 100644 --- a/src/input/sources/ytdl.rs +++ b/src/input/sources/ytdl.rs @@ -18,6 +18,12 @@ use tokio::process::Command; const YOUTUBE_DL_COMMAND: &str = "yt-dlp"; +#[derive(Clone, Debug)] +enum QueryType { + Url(String), + Search(String), +} + /// A lazily instantiated call to download a file, finding its URL via youtube-dl. /// /// By default, this uses yt-dlp and is backed by an [`HttpRequest`]. This handler @@ -30,7 +36,7 @@ pub struct YoutubeDl { program: &'static str, client: Client, metadata: Option, - url: String, + query: QueryType, } impl YoutubeDl { @@ -52,14 +58,63 @@ impl YoutubeDl { program, client, metadata: None, - url, + query: QueryType::Url(url), + } + } + + /// Creates a request to search youtube for an optionally specified number of videos matching `query`, + /// using "yt-dlp". + #[must_use] + pub fn new_search(client: Client, query: String) -> Self { + Self::new_search_ytdl_like(YOUTUBE_DL_COMMAND, client, query) + } + + /// Creates a request to search youtube for an optionally specified number of videos matching `query`, + /// using `program`. + #[must_use] + pub fn new_search_ytdl_like(program: &'static str, client: Client, query: String) -> Self { + Self { + program, + client, + metadata: None, + query: QueryType::Search(query), } } - async fn query(&mut self) -> Result { + /// Runs a search for the given query, returning a list of up to `n_results` + /// possible matches which are `AuxMetadata` objects containing a valid URL. + /// + /// Returns up to 5 matches by default. + pub async fn search( + &mut self, + n_results: Option, + ) -> Result, AudioStreamError> { + let n_results = n_results.unwrap_or(5); + + Ok(match &self.query { + // Safer to just return the metadata for the pointee if possible + QueryType::Url(_) => vec![self.aux_metadata().await?], + QueryType::Search(_) => self + .query(n_results) + .await? + .into_iter() + .map(|v| v.as_aux_metadata()) + .collect(), + }) + } + + async fn query(&mut self, n_results: usize) -> Result, AudioStreamError> { + let new_query; + let query_str = match &self.query { + QueryType::Url(url) => url, + QueryType::Search(query) => { + new_query = format!("ytsearch{n_results}:{query}"); + &new_query + }, + }; let ytdl_args = [ "-j", - &self.url, + query_str, "-f", "ba[abr>0][vcodec=none]/best", "--no-playlist", @@ -77,14 +132,35 @@ impl YoutubeDl { }) })?; - // NOTE: must be mut for simd-json. - #[allow(clippy::unnecessary_mut_passed)] - let stdout: Output = crate::json::from_slice(&mut output.stdout[..]) + if !output.status.success() { + return Err(AudioStreamError::Fail( + format!( + "{} failed with non-zero status code: {}", + self.program, + std::str::from_utf8(&output.stderr[..]).unwrap_or("") + ) + .into(), + )); + } + + // NOTE: must be split_mut for simd-json. + let out = output + .stdout + .split_mut(|&b| b == b'\n') + .filter_map(|x| (!x.is_empty()).then(|| crate::json::from_slice(x))) + .collect::, _>>() .map_err(|e| AudioStreamError::Fail(Box::new(e)))?; - self.metadata = Some(stdout.as_aux_metadata()); + let meta = out + .first() + .ok_or_else(|| { + AudioStreamError::Fail(format!("no results found for '{query_str}'").into()) + })? + .as_aux_metadata(); - Ok(stdout) + self.metadata = Some(meta); + + Ok(out) } } @@ -103,11 +179,13 @@ impl Compose for YoutubeDl { async fn create_async( &mut self, ) -> Result>, AudioStreamError> { - let stdout = self.query().await?; + // panic safety: `query` should have ensured > 0 results if `Ok` + let mut results = self.query(1).await?; + let result = results.swap_remove(0); let mut headers = HeaderMap::default(); - if let Some(map) = stdout.http_headers { + if let Some(map) = result.http_headers { headers.extend(map.iter().filter_map(|(k, v)| { Some(( HeaderName::from_bytes(k.as_bytes()).ok()?, @@ -118,9 +196,9 @@ impl Compose for YoutubeDl { let mut req = HttpRequest { client: self.client.clone(), - request: stdout.url, + request: result.url, headers, - content_length: stdout.filesize, + content_length: result.filesize, }; req.create_async().await @@ -135,7 +213,7 @@ impl Compose for YoutubeDl { return Ok(meta.clone()); } - self.query().await?; + self.query(1).await?; self.metadata.clone().ok_or_else(|| { let msg: Box = @@ -185,4 +263,25 @@ mod tests { assert!(ytdl.aux_metadata().await.is_err()); } + + #[tokio::test] + #[ntest::timeout(20_000)] + async fn ytdl_search_plays() { + let mut ytdl = YoutubeDl::new_search(Client::new(), "cloudkicker 94 days".into()); + let res = ytdl.search(Some(1)).await; + + let res = res.unwrap(); + assert_eq!(res.len(), 1); + + track_plays_passthrough(move || ytdl).await; + } + + #[tokio::test] + #[ntest::timeout(20_000)] + async fn ytdl_search_3() { + let mut ytdl = YoutubeDl::new_search(Client::new(), "test".into()); + let res = ytdl.search(Some(3)).await; + + assert_eq!(res.unwrap().len(), 3); + } }