-
Notifications
You must be signed in to change notification settings - Fork 49
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
1 changed file
with
369 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,369 @@ | ||
(* -------------------------------------------------------------------- *) | ||
require import AllCore List Distr DList Number StdOrder StdBigop. | ||
require import RealSeries. | ||
require (*--*) DynMatrix. | ||
(*---*) import IntOrder RealOrder RField Bigint Bigreal. | ||
|
||
(* -------------------------------------------------------------------- *) | ||
clone import DynMatrix as DM. | ||
(*-*) import DM.ZR. | ||
|
||
(* -------------------------------------------------------------------- *) | ||
abbrev "_.[_]" ['a] (xs : 'a list) (i : int) = nth<:'a> witness xs i. | ||
|
||
(* -------------------------------------------------------------------- *) | ||
lemma compE ['a 'b 'c] (f : 'a -> 'b) (g : 'b -> 'c) (x : 'a) : | ||
(g \o f) x = g (f x). | ||
proof. done. qed. | ||
hint simplify compE. | ||
(* -------------------------------------------------------------------- *) | ||
lemma dlist_ubound (n : int) (d : R distr) (E : R -> bool) : 0 <= n => | ||
mu | ||
(dlist d n) | ||
(fun xs => exists i, 0 <= i < n /\ E xs.[i]) | ||
<= n%r * mu d E. | ||
proof. | ||
elim: n => /= [|n ge0_n ih]; first by rewrite dlist0 // dunitE //#. | ||
rewrite dlistS //= dmapE /(\o) /=. | ||
pose P1 (x : R) := E x. | ||
pose P2 (xs : R list) := exists i, (0 <= i < n /\ E xs.[i]). | ||
pose P (x_xs : R * R list) := P1 x_xs.`1 \/ P2 x_xs.`2. | ||
rewrite (mu_eq_support _ _ P). | ||
- case=> [x xs] /supp_dprod /= [_]. | ||
case/(supp_dlist _ _ _ ge0_n) => [sz_xs _]. | ||
rewrite /P /=; apply/eq_iff; split; first smt(). | ||
case=> [Ex|]; first exists 0; smt(). | ||
by case=> i rg_i; exists (i+1) => //#. | ||
apply: (ler_trans _ _ _ (le_dprod_or _ _ _ _)). | ||
rewrite fromintD mulrDl /= addrC ler_add. | ||
- by apply: (ler_trans _ _ _ (ler_pimulr _ _ _ _)). | ||
- by apply: (ler_trans _ _ _ (ler_pimulr _ _ _ _)). | ||
qed. | ||
(* -------------------------------------------------------------------- *) | ||
lemma L ['a 'b 'c 'd 'e 'ab 'ac 'bd 'cd] | ||
(da : 'a distr) (db : 'b distr) (dc : 'c distr) (dd : 'd distr) | ||
(Fab : 'a * 'b -> 'ab) | ||
(Fcd : 'c * 'd -> 'cd) | ||
(F : 'ab -> 'cd -> 'e) | ||
(Fac : 'a * 'c -> 'ac) | ||
(Fbd : 'b * 'd -> 'bd) | ||
(G : 'ac -> 'bd -> 'e) | ||
: | ||
(forall a b c d, F (Fab (a, b)) (Fcd (c, d)) = G (Fac (a, c)) (Fbd (b, d))) => | ||
|
||
dlet | ||
(dmap (da `*` db) Fab) | ||
(fun ab => | ||
dmap | ||
(dmap (dc `*` dd) Fcd) | ||
(fun cd => F ab cd)) | ||
= dlet | ||
(dmap (da `*` dc) Fac) | ||
(fun ac => | ||
dmap | ||
(dmap (db `*` dd) Fbd) | ||
(fun bd => G ac bd)). | ||
proof. | ||
pose D1 := dlet (da `*` db) | ||
(fun ab => dlet dc (fun c => dmap dd (fun d => F (Fab ab) (Fcd (c, d))))). | ||
move=> eq; rewrite dlet_dmap /= &(eq_trans _ D1) /D1 => {D1}. | ||
- by rewrite &(eq_dlet) // => ab /=; rewrite dmap_comp dmap_dprodE. | ||
pose D2 := dlet (da `*` dc) | ||
(fun ac => dlet db (fun b => dmap dd (fun d => G (Fac ac) (Fbd (b, d))))). | ||
rewrite dlet_dmap /= &(eq_trans _ D2) /D2 => {D2}; last first. | ||
- by rewrite &(eq_dlet) // => ac /=; rewrite dmap_comp dmap_dprodE. | ||
rewrite !dprod_dlet !dlet_dlet /= &(eq_dlet) // => a /=. | ||
rewrite dlet_dlet /= dlet_swap &(eq_dlet) // => b /=. | ||
rewrite 2!(dlet_dunit, dlet_unit) /= dlet_dmap. | ||
rewrite &(eq_dlet) // => c /=; rewrite &(eq_dmap) // => d /=. | ||
by apply: eq. | ||
qed. | ||
|
||
(* -------------------------------------------------------------------- *) | ||
lemma L2 ['a 'b 'c 'd 'e] | ||
(da : 'a distr) (db : 'b distr) (dc : 'c distr) (dd : 'd distr) | ||
(F : 'a -> 'b -> 'c -> 'd -> 'e) | ||
: | ||
dlet | ||
(da `*` db) | ||
(fun ab : 'a * 'b => | ||
dmap | ||
(dc `*` dd) | ||
(fun cd : 'c * 'd => F ab.`1 ab.`2 cd.`1 cd.`2)) | ||
= dlet | ||
(da `*` dc) | ||
(fun ac : 'a * 'c => | ||
dmap | ||
(db `*` dd) | ||
(fun bd : 'b * 'd => F ac.`1 bd.`1 ac.`2 bd.`2)). | ||
proof. | ||
pose F1 (ab : 'a * 'b) (cd : 'c * 'd) := F ab.`1 ab.`2 cd.`1 cd.`2. | ||
pose F2 (ac : 'a * 'c) (bd : 'b * 'd) := F ac.`1 bd.`1 ac.`2 bd.`2. | ||
have := L da db dc dd idfun idfun F1 idfun idfun F2 _; first done. | ||
by rewrite !dmap_id. | ||
qed. | ||
(* -------------------------------------------------------------------- *) | ||
lemma dprod_dunit ['a 'b] (x : 'a) (y : 'b) : | ||
dunit x `*` dunit y = dunit (x, y). | ||
proof. | ||
by apply: eq_distr => -[a b]; rewrite dprod1E !dunit1E /#. | ||
qed. | ||
(* -------------------------------------------------------------------- *) | ||
op dadd (d1 d2 : R distr) = | ||
dmap (d1 `*` d2) (fun xy : R * R => xy.`1 + xy.`2). | ||
(* -------------------------------------------------------------------- *) | ||
lemma dlistD (n : int) (d1 d2 : R distr) : 0 <= n => | ||
dlet (dlist d1 n) (fun (xs : R list) => | ||
dmap (dlist d2 n) (fun (ys : R list) => | ||
mkseq (fun i => xs.[i] + ys.[i]) n)) | ||
= dlist (dadd d1 d2) n. | ||
proof. | ||
pose S n (xs ys : R list) := mkseq (fun i => xs.[i] + ys.[i]) n. | ||
pose T n (xs : R list * R list) := S n xs.`1 xs.`2. | ||
move=> ge0_n; rewrite -(dmap_dprodE _ _ (T n)). (* SLOW *) | ||
elim: n ge0_n => /= [|n ge0_n ih]; last rewrite !dlistS //. | ||
- by rewrite !dlist0 // dprod_dunit dmap_dunit /T /S /= mkseq0. | ||
pose C (x_xs : R * R list) := x_xs.`1 :: x_xs.`2. | ||
pose F (x : R * R) (xs : R list * R list) := | ||
S (n+1) (x.`1 :: xs.`1) (x.`2 :: xs.`2). | ||
pose G (xs ys : R list) := S (n+1) xs ys. | ||
rewrite dmap_dprodE; have -> := | ||
L d1 (dlist d1 n) d2 (dlist d2 n) C C G idfun idfun F _; first by done. | ||
rewrite !dmap_id /= dmap_dprodE {1}/dadd dlet_dmap. | ||
apply/eq_dlet => // -[x y] /=; rewrite -ih. | ||
rewrite dmap_comp &(eq_dmap) => -[xs ys] /=. | ||
by rewrite /F /S /C /T /= mkseqSr //= &(eq_in_mkseq) //#. | ||
qed. | ||
(* -------------------------------------------------------------------- *) | ||
lemma dlist_dlist ['a] (d : 'a distr) (m n : int) : | ||
0 <= m => 0 <= n => | ||
dmap (dlist (dlist d m) n) flatten = dlist d (m * n). | ||
proof. | ||
move=> ge0_m; elim: n => /= [|n ge0_n ih]. | ||
- by rewrite !dlist0 // dmap_dunit. | ||
rewrite mulrDr /= [dlist d (m * n + m)]dlist_add //. | ||
- by apply: IntOrder.mulr_ge0. | ||
rewrite dlistSr //= dmap_comp !dmap_dprodE /=. | ||
rewrite -ih dlet_dmap /= &(eq_dlet) // => xss /=. | ||
by rewrite &(eq_dmap) => xs /=; rewrite flatten_rcons. | ||
qed. | ||
(* -------------------------------------------------------------------- *) | ||
lemma dmatrix_dlist (r c : int) (d : R distr) : | ||
0 <= r => 0 <= c => dmatrix d r c = | ||
dmap | ||
(dlist d (r * c)) | ||
(fun vs => offunm ((fun i j => vs.[j * r + i]), r, c)). | ||
proof. | ||
move=> ge0_r ge0_c @/dmatrix @/dvector. | ||
rewrite dlist_dmap dmap_comp !lez_maxr //. | ||
rewrite -dlist_dlist // dmap_comp &(eq_dmap_in) => xss /=. | ||
case/(supp_dlist _ _ _ ge0_c) => size_xss /allP xssE. | ||
have {xssE} xssE: forall xs, xs \in xss => size xs = r. | ||
- by move=> xs /xssE /(supp_dlist _ _ _ ge0_r). | ||
apply/eq_matrixP=> @/ofcols /= i j []. | ||
rewrite !lez_maxr // => rgi rgj. | ||
rewrite !get_offunm /= ?lez_maxr //. | ||
rewrite (nth_map witness) 1:/#. | ||
rewrite (get_oflist witness) 1:#smt:(mem_nth). | ||
rewrite -nth_flatten ~-1:#smt:(mem_nth); do 2! congr. | ||
rewrite sumzE BIA.big_map predT_comp /(\o) /=. | ||
pose D := BIA.big predT (fun _ => r) (take j xss). | ||
apply: (eq_trans _ D) => @/D. | ||
- rewrite !BIA.big_seq &(BIA.eq_bigr) //=. | ||
by move=> xs /mem_take /xssE. | ||
by rewrite big_constz count_predT size_take //#. | ||
qed. | ||
(* -------------------------------------------------------------------- *) | ||
lemma dmatrixD (r c : int) (d1 d2 : R distr) : 0 <= r => 0 <= c => | ||
dlet (dmatrix d1 r c) (fun (m1 : matrix) => | ||
dmap (dmatrix d2 r c) (fun (m2 : matrix) => m1 + m2)) | ||
= dmatrix (dadd d1 d2) r c. | ||
proof. | ||
move=> ge0_r ge0_c; rewrite 2?dmatrix_dlist //=. | ||
pose F vs := offunm (fun i j => vs.[j * r + i], r, c). | ||
rewrite dlet_dmap /= dlet_swap dlet_dmap /= dlet_swap /=. | ||
rewrite dmatrix_dlist // -/F -dlistD ~-1:/#. | ||
rewrite dmap_dlet &(eq_dlet) // => xs /=. | ||
rewrite dlet_dunit dmap_comp &(eq_dmap) => ys /=. | ||
apply/eq_matrixP; split. | ||
- by rewrite /F size_addm !size_offunm. | ||
move=> i j []; rewrite rows_addm cols_addm /=. | ||
rewrite !rows_offunm !cols_offunm !maxzz => rgi rgj. | ||
by rewrite get_addm !get_offunm //= nth_mkseq //#. | ||
qed. | ||
(* -------------------------------------------------------------------- *) | ||
op dmul (n : int) (d1 d2 : R distr) = | ||
dmap | ||
(dlist d1 n `*` dlist d2 n) | ||
(fun vs : R list * R list => | ||
DM.Big.BAdd.big predT | ||
(fun xy : R * R => xy.`1 * xy.`2) | ||
(zip vs.`1 vs.`2)). | ||
(* -------------------------------------------------------------------- *) | ||
lemma foo ['a 'b 'c] (da : 'a distr) (db : 'b distr) (f : 'a -> 'c) : | ||
dmap (da `*` db) (fun ab : 'a * 'b => f ab.`1) | ||
= weight db \cdot dmap da f. | ||
proof. by rewrite dmap_dprodE_swap /= dlet_cst_weight. qed. | ||
|
||
(* -------------------------------------------------------------------- *) | ||
lemma dmatrix_cols (d : R distr) (r c : int) : 0 <= c => 0 <= r => | ||
dmatrix d r c = dmap (dlist (dvector d r) c) (ofcols r c). | ||
proof. by move=> ge0_c ge0_r @/dmatrix; rewrite lez_maxr. qed. | ||
|
||
(* -------------------------------------------------------------------- *) | ||
lemma dmatrix_rows (d : R distr) (r c : int) : 0 <= c => 0 <= r => | ||
dmatrix d r c = dmap (dlist (dvector d c) r) (trmx \o ofcols c r). | ||
proof. | ||
move=> ge0_r ge0_c; rewrite -dmap_comp -dmatrix_cols //. | ||
apply/eq_distr => /= m; rewrite (dmap1E _ trmx). | ||
have ->: pred1 m \o trmx = pred1 (trmx m) by smt(trmxK). | ||
case: (size m = (r, c)); last first. | ||
- by move=> ne_size; rewrite !dmatrix0E //#. | ||
case=> <<- <<-; rewrite -{2}rows_tr -{2}cols_tr !dmatrix1E /=. | ||
by rewrite BRM.exchange_big. | ||
qed. | ||
|
||
(* -------------------------------------------------------------------- *) | ||
hint simplify drop0, take0, cats0, cat0s. | ||
|
||
(* -------------------------------------------------------------------- *) | ||
(* FIXME: refactor *) | ||
|
||
lemma dlist_insert ['a] (i n : int) (d : 'a distr) : | ||
0 <= n => 0 <= i <= n => dlist d (n+1) = | ||
dmap (d `*` dlist d n) (fun x_xs : 'a * 'a list => insert x_xs.`1 x_xs.`2 i). | ||
proof. | ||
move=> ge0_n [ge0_i lti]; apply/eq_sym. | ||
pose f (x_xs : _ * _) := insert x_xs.`1 x_xs.`2 i. | ||
pose g (xs : 'a list) := (xs.[i], take i xs ++ drop (i+1) xs). | ||
have ge0_Sn: 0 <= n + 1 by smt(). apply: (dmap_bij _ _ f g). | ||
- case=> [x xs] /supp_dprod[/=] x_in_d. | ||
case/(supp_dlist _ _ _ ge0_n)=> sz_xs /allP xs_in_d. | ||
move=> @/f /=; apply/supp_dlist; first smt(). | ||
rewrite size_insert ?sz_xs //=; apply/allP. | ||
by move=> y /mem_insert[->>//|/xs_in_d]. | ||
- move=> xs /(supp_dlist _ _ _ ge0_Sn)[sz_xs /allP xs_in_d] @/g. | ||
rewrite dprod1E !dlist1E ~-1://# sz_xs /=. | ||
rewrite size_cat size_take // size_drop 1:/#. | ||
rewrite iftrue 1:/# -(BRM.big_consT (mu1 d)) &(BRM.eq_big_perm). | ||
by rewrite -cat_cons perm_eq_sym &(perm_eq_nth_take_drop) //#. | ||
- case=> x xs /supp_dprod[/=] _ /(supp_dlist _ _ _ ge0_n)[sz_xs _]. | ||
rewrite /g /f /= nth_insert ?sz_xs //= take_insert_le 1:/#. | ||
by rewrite drop_insert_gt 1:/# /= cat_take_drop. | ||
- move=> xs /(supp_dlist _ _ _ ge0_Sn)[/=] sz_xs _ @/f @/g /=. | ||
have sz_take: size (take i xs) = i by rewrite size_take //#. | ||
by apply/insert_nth_take_drop => //#. | ||
qed. | ||
|
||
hint simplify insert0. | ||
|
||
(* -------------------------------------------------------------------- *) | ||
lemma dmatrix_cols_i (i : int) (d : R distr) (r c : int) : | ||
0 <= c => 0 <= r => 0 <= i < c => | ||
dmatrix d r c = | ||
dmap | ||
(dvector d r `*` dlist (dvector d r) (c-1)) | ||
(fun c_cs : _ * _ => ofcols r c (insert c_cs.`1 c_cs.`2 i)). | ||
proof. | ||
move=> ge0_c ge0_r rgi; rewrite dmatrix_cols //. | ||
rewrite {1}(_ : c = (c - 1) + 1) // (dlist_insert i) ~-1://# /=. | ||
by rewrite dmap_comp &(eq_dmap) => -[v vs]. | ||
qed. | ||
|
||
(* -------------------------------------------------------------------- *) | ||
lemma dmatrix_rows_i (j : int) (d : R distr) (r c : int) : | ||
0 <= c => 0 <= r => 0 <= j < r => | ||
dmatrix d r c = | ||
dmap | ||
(dvector d c `*` dlist (dvector d c) (r-1)) | ||
(fun r_rs : _ * _ => trmx (ofcols c r (insert r_rs.`1 r_rs.`2 j))). | ||
proof. | ||
move=> ge0_c ge0_r rgj; rewrite dmatrix_rows //. | ||
rewrite {1}(_ : r = (r - 1) + 1) // (dlist_insert j) ~-1://# /=. | ||
by rewrite dmap_comp &(eq_dmap) => -[v vs]. | ||
qed. | ||
|
||
(* -------------------------------------------------------------------- *) | ||
lemma col_ofcols (i r c : int) (vs : vector list) : | ||
0 <= r => 0 <= c => 0 <= i < c | ||
=> size vs = c | ||
=> all (fun v : vector => size v = r) vs | ||
=> col (ofcols r c vs) i = vs.[i]. | ||
proof. | ||
move=> ge0_r ge0_c rgi sz_vs /allP => sz_in_vs. | ||
have sz_rows: rows (ofcols r c vs) = r. | ||
- by rewrite rows_offunm lez_maxr // sz_in_vs. | ||
apply/eq_vectorP; split=> /=. | ||
- by rewrite sz_rows sz_in_vs // &(mem_nth) sz_vs. | ||
by move=> j; rewrite sz_rows => rgj; rewrite get_offunm //#. | ||
qed. | ||
|
||
(* -------------------------------------------------------------------- *) | ||
lemma dmatrixM (m n p : int) (d1 d2 : R distr) : | ||
0 <= m => 0 <= n => 0 <= p => | ||
|
||
let d = | ||
dlet (dmatrix d1 m n) (fun (m1 : matrix) => | ||
dmap (dmatrix d2 n p) (fun (m2 : matrix) => m1 * m2)) in | ||
|
||
forall i j, 0 <= i < m => 0 <= j < p => | ||
dmap d (fun m => m.[i, j]) = | ||
((weight d1) ^ (n * (m-1)) * (weight d2) ^ (n * (p-1))) \cdot dmul n d1 d2. | ||
proof. | ||
move=> ge0_m ge0_n ge0_p d i j rg_i rg_j. | ||
have [gt0_m gt0_p]: (0 <= m-1) /\ (0 <= p-1) by smt(). | ||
rewrite /d (dmatrix_rows_i i) //= (dmatrix_cols_i j) //=. | ||
pose D1 := dvector d1 n `*` _; pose D2 := dvector d2 n `*` _. | ||
pose F1 := fun (r_rs : _ * _) => trmx (ofcols n m (insert r_rs.`1 r_rs.`2 i)). | ||
pose F2 := fun (c_cs : _ * _) => ofcols n p (insert c_cs.`1 c_cs.`2 j). | ||
pose F r rs c cs := (trmx (ofcols n m (insert r rs i)) * ofcols n p (insert c cs j)).[i, j]. | ||
pose D := dlet D1 (fun c : _ * _ => dmap D2 (fun r : _ * _ => F c.`1 c.`2 r.`1 r.`2)). | ||
apply: (eq_trans _ D) => @/D => {D}. | ||
- rewrite dmap_dlet dlet_dmap /= &(eq_dlet) // => ? /=. | ||
by rewrite 2!dmap_comp &(eq_dmap). | ||
pose G (x_xs : (_ * _) * (_ * _)) := F x_xs.`1.`1 x_xs.`2.`1 x_xs.`1.`2 x_xs.`2.`2. | ||
rewrite L2 /= => {D1 D2}; pose D1 := _ `*` _; pose D2 := _ `*` _. | ||
have @/G /= <- := dmap_dprodE D1 D2 G => {G}. | ||
pose G (vs : vector * vector) := dotp vs.`1 vs.`2. | ||
apply: (eq_trans _ (dmap (D1 `*` D2) (fun x : _ * _ => G x.`1))). | ||
- apply: eq_dmap_in=> -[[c r] [cs rs]] @/G @/F /=. | ||
case/supp_dprod=> /= /supp_dprod[/=]. | ||
case/(supp_dvector _ _ _ ge0_n) => sz_c _. | ||
case/(supp_dvector _ _ _ ge0_n) => sz_r _. | ||
move/supp_dprod=> [/=]. | ||
case/(supp_dlist _ _ _ gt0_m) => [sz_cs /allP sz_in_cs]. | ||
case/(supp_dlist _ _ _ gt0_p) => [sz_rs /allP sz_in_rs]. | ||
rewrite get_mulmx row_trmx /= !col_ofcols //. | ||
- by rewrite size_insert ?sz_cs //#. | ||
- apply/allP=> v /mem_insert [->>|] //=. | ||
by move/sz_in_cs => /(supp_dvector _ _ _ ge0_n). | ||
- by rewrite size_insert ?sz_rs //#. | ||
- apply/allP=> v /mem_insert [->>|] //=. | ||
by move/sz_in_rs => /(supp_dvector _ _ _ ge0_n). | ||
by rewrite !nth_insert // (sz_cs, sz_rs) //#. | ||
rewrite foo /D2 weight_dprod !weight_dlist // !weight_dmap. | ||
rewrite !weight_dlist ?lez_maxr // -!exprM. | ||
congr=> @/D1 @/G => {D1 D2 G} @/dmul. | ||
rewrite !dmap_dprodE /= dlet_dmap lez_maxr //. | ||
apply/in_eq_dlet => //= xs /(supp_dlist _ _ _ ge0_n)[sz_xs _]. | ||
rewrite dmap_comp lez_maxr //; apply/eq_dmap_in => /= ys. | ||
case/(supp_dlist _ _ _ ge0_n)=> sz_ys _ @/dotp. | ||
rewrite !size_oflist sz_xs sz_ys lez_maxr //. | ||
apply/eq_sym; rewrite (Big.BAdd.big_nth witness) predT_comp. | ||
rewrite size_zip sz_xs sz_ys lez_minr //. | ||
rewrite !Big.BAdd.big_seq /= &(Big.BAdd.eq_bigr) /=. | ||
move=> k /mem_range rg_k; rewrite !(get_oflist witness) ~-1://#. | ||
have := nth_zip witness witness xs ys k _; first by smt(). | ||
by rewrite (nth_change_dfl witness) => [|->//]; rewrite size_zip /#. | ||
qed. |