1
1
use derive_more:: Display ;
2
2
use serde:: { Deserialize , Serialize } ;
3
3
use std:: num:: NonZeroU64 ;
4
+ use tokio_util:: task:: { task_tracker:: TaskTrackerWaitFuture , TaskTracker } ;
4
5
5
6
/// A [`TaskId`](https://docs.rs/tokio/latest/tokio/task/struct.Id.html) that can be `serde`.
6
7
#[ derive( Debug , Display , Serialize , Deserialize , Copy , Clone , Eq , PartialEq , Hash ) ]
@@ -13,6 +14,20 @@ impl From<tokio::task::Id> for TaskId {
13
14
}
14
15
}
15
16
17
+ /// Execute [`close`](https://docs.rs/tokio-util/latest/tokio_util/task/task_tracker/struct.TaskTracker.html#method.close)
18
+ /// and [`wait`](https://docs.rs/tokio-util/latest/tokio_util/task/task_tracker/struct.TaskTracker.html#method.wait)
19
+ /// for [`TaskTracker`](https://docs.rs/tokio-util/latest/tokio_util/task/task_tracker/struct.TaskTracker.html) at once.
20
+ pub trait CloseAndWait {
21
+ fn close_and_wait ( & self ) -> TaskTrackerWaitFuture ;
22
+ }
23
+
24
+ impl CloseAndWait for TaskTracker {
25
+ fn close_and_wait ( & self ) -> TaskTrackerWaitFuture {
26
+ self . close ( ) ;
27
+ self . wait ( )
28
+ }
29
+ }
30
+
16
31
#[ cfg( test) ]
17
32
mod tests {
18
33
use super :: * ;
@@ -22,4 +37,38 @@ mod tests {
22
37
let id = tokio:: spawn ( async { tokio:: task:: id ( ) } ) . await . unwrap ( ) ;
23
38
assert_eq ! ( id. to_string( ) , TaskId :: from( id) . to_string( ) ) ;
24
39
}
40
+
41
+ fn tracker_spawn ( ) -> TaskTracker {
42
+ let tracker = TaskTracker :: new ( ) ;
43
+
44
+ for i in 0 ..3 {
45
+ tracker. spawn ( async move { i } ) ;
46
+ }
47
+
48
+ tracker
49
+ }
50
+
51
+ #[ tokio:: test]
52
+ async fn close_and_wait ( ) {
53
+ use std:: time:: Duration ;
54
+ use tokio:: time:: timeout;
55
+
56
+ let tracker = tracker_spawn ( ) ;
57
+ assert ! ( timeout( Duration :: from_secs_f64( 1.5 ) , tracker. wait( ) )
58
+ . await
59
+ . is_err( ) ) ;
60
+
61
+ let tracker = tracker_spawn ( ) ;
62
+ tracker. close ( ) ;
63
+ assert ! ( timeout( Duration :: from_secs_f64( 1.5 ) , tracker. wait( ) )
64
+ . await
65
+ . is_ok( ) ) ;
66
+
67
+ let tracker = tracker_spawn ( ) ;
68
+ assert ! (
69
+ timeout( Duration :: from_secs_f64( 1.5 ) , tracker. close_and_wait( ) )
70
+ . await
71
+ . is_ok( )
72
+ ) ;
73
+ }
25
74
}
0 commit comments