Skip to content

Commit

Permalink
DistrMatrix
Browse files Browse the repository at this point in the history
  • Loading branch information
strub committed Jul 17, 2024
1 parent a333999 commit 2754b01
Showing 1 changed file with 369 additions and 0 deletions.
369 changes: 369 additions & 0 deletions examples/distrmatrix.ec
Original file line number Diff line number Diff line change
@@ -0,0 +1,369 @@
(* -------------------------------------------------------------------- *)
require import AllCore List Distr DList Number StdOrder StdBigop.
require import RealSeries.
require (*--*) DynMatrix.
(*---*) import IntOrder RealOrder RField Bigint Bigreal.

(* -------------------------------------------------------------------- *)
clone import DynMatrix as DM.
(*-*) import DM.ZR.

(* -------------------------------------------------------------------- *)
abbrev "_.[_]" ['a] (xs : 'a list) (i : int) = nth<:'a> witness xs i.

(* -------------------------------------------------------------------- *)
lemma compE ['a 'b 'c] (f : 'a -> 'b) (g : 'b -> 'c) (x : 'a) :
(g \o f) x = g (f x).
proof. done. qed.
hint simplify compE.
(* -------------------------------------------------------------------- *)
lemma dlist_ubound (n : int) (d : R distr) (E : R -> bool) : 0 <= n =>
mu
(dlist d n)
(fun xs => exists i, 0 <= i < n /\ E xs.[i])
<= n%r * mu d E.
proof.
elim: n => /= [|n ge0_n ih]; first by rewrite dlist0 // dunitE //#.
rewrite dlistS //= dmapE /(\o) /=.
pose P1 (x : R) := E x.
pose P2 (xs : R list) := exists i, (0 <= i < n /\ E xs.[i]).
pose P (x_xs : R * R list) := P1 x_xs.`1 \/ P2 x_xs.`2.
rewrite (mu_eq_support _ _ P).
- case=> [x xs] /supp_dprod /= [_].
case/(supp_dlist _ _ _ ge0_n) => [sz_xs _].
rewrite /P /=; apply/eq_iff; split; first smt().
case=> [Ex|]; first exists 0; smt().
by case=> i rg_i; exists (i+1) => //#.
apply: (ler_trans _ _ _ (le_dprod_or _ _ _ _)).
rewrite fromintD mulrDl /= addrC ler_add.
- by apply: (ler_trans _ _ _ (ler_pimulr _ _ _ _)).
- by apply: (ler_trans _ _ _ (ler_pimulr _ _ _ _)).
qed.
(* -------------------------------------------------------------------- *)
lemma L ['a 'b 'c 'd 'e 'ab 'ac 'bd 'cd]
(da : 'a distr) (db : 'b distr) (dc : 'c distr) (dd : 'd distr)
(Fab : 'a * 'b -> 'ab)
(Fcd : 'c * 'd -> 'cd)
(F : 'ab -> 'cd -> 'e)
(Fac : 'a * 'c -> 'ac)
(Fbd : 'b * 'd -> 'bd)
(G : 'ac -> 'bd -> 'e)
:
(forall a b c d, F (Fab (a, b)) (Fcd (c, d)) = G (Fac (a, c)) (Fbd (b, d))) =>

dlet
(dmap (da `*` db) Fab)
(fun ab =>
dmap
(dmap (dc `*` dd) Fcd)
(fun cd => F ab cd))
= dlet
(dmap (da `*` dc) Fac)
(fun ac =>
dmap
(dmap (db `*` dd) Fbd)
(fun bd => G ac bd)).
proof.
pose D1 := dlet (da `*` db)
(fun ab => dlet dc (fun c => dmap dd (fun d => F (Fab ab) (Fcd (c, d))))).
move=> eq; rewrite dlet_dmap /= &(eq_trans _ D1) /D1 => {D1}.
- by rewrite &(eq_dlet) // => ab /=; rewrite dmap_comp dmap_dprodE.
pose D2 := dlet (da `*` dc)
(fun ac => dlet db (fun b => dmap dd (fun d => G (Fac ac) (Fbd (b, d))))).
rewrite dlet_dmap /= &(eq_trans _ D2) /D2 => {D2}; last first.
- by rewrite &(eq_dlet) // => ac /=; rewrite dmap_comp dmap_dprodE.
rewrite !dprod_dlet !dlet_dlet /= &(eq_dlet) // => a /=.
rewrite dlet_dlet /= dlet_swap &(eq_dlet) // => b /=.
rewrite 2!(dlet_dunit, dlet_unit) /= dlet_dmap.
rewrite &(eq_dlet) // => c /=; rewrite &(eq_dmap) // => d /=.
by apply: eq.
qed.

(* -------------------------------------------------------------------- *)
lemma L2 ['a 'b 'c 'd 'e]
(da : 'a distr) (db : 'b distr) (dc : 'c distr) (dd : 'd distr)
(F : 'a -> 'b -> 'c -> 'd -> 'e)
:
dlet
(da `*` db)
(fun ab : 'a * 'b =>
dmap
(dc `*` dd)
(fun cd : 'c * 'd => F ab.`1 ab.`2 cd.`1 cd.`2))
= dlet
(da `*` dc)
(fun ac : 'a * 'c =>
dmap
(db `*` dd)
(fun bd : 'b * 'd => F ac.`1 bd.`1 ac.`2 bd.`2)).
proof.
pose F1 (ab : 'a * 'b) (cd : 'c * 'd) := F ab.`1 ab.`2 cd.`1 cd.`2.
pose F2 (ac : 'a * 'c) (bd : 'b * 'd) := F ac.`1 bd.`1 ac.`2 bd.`2.
have := L da db dc dd idfun idfun F1 idfun idfun F2 _; first done.
by rewrite !dmap_id.
qed.
(* -------------------------------------------------------------------- *)
lemma dprod_dunit ['a 'b] (x : 'a) (y : 'b) :
dunit x `*` dunit y = dunit (x, y).
proof.
by apply: eq_distr => -[a b]; rewrite dprod1E !dunit1E /#.
qed.
(* -------------------------------------------------------------------- *)
op dadd (d1 d2 : R distr) =
dmap (d1 `*` d2) (fun xy : R * R => xy.`1 + xy.`2).
(* -------------------------------------------------------------------- *)
lemma dlistD (n : int) (d1 d2 : R distr) : 0 <= n =>
dlet (dlist d1 n) (fun (xs : R list) =>
dmap (dlist d2 n) (fun (ys : R list) =>
mkseq (fun i => xs.[i] + ys.[i]) n))
= dlist (dadd d1 d2) n.
proof.
pose S n (xs ys : R list) := mkseq (fun i => xs.[i] + ys.[i]) n.
pose T n (xs : R list * R list) := S n xs.`1 xs.`2.
move=> ge0_n; rewrite -(dmap_dprodE _ _ (T n)). (* SLOW *)
elim: n ge0_n => /= [|n ge0_n ih]; last rewrite !dlistS //.
- by rewrite !dlist0 // dprod_dunit dmap_dunit /T /S /= mkseq0.
pose C (x_xs : R * R list) := x_xs.`1 :: x_xs.`2.
pose F (x : R * R) (xs : R list * R list) :=
S (n+1) (x.`1 :: xs.`1) (x.`2 :: xs.`2).
pose G (xs ys : R list) := S (n+1) xs ys.
rewrite dmap_dprodE; have -> :=
L d1 (dlist d1 n) d2 (dlist d2 n) C C G idfun idfun F _; first by done.
rewrite !dmap_id /= dmap_dprodE {1}/dadd dlet_dmap.
apply/eq_dlet => // -[x y] /=; rewrite -ih.
rewrite dmap_comp &(eq_dmap) => -[xs ys] /=.
by rewrite /F /S /C /T /= mkseqSr //= &(eq_in_mkseq) //#.
qed.
(* -------------------------------------------------------------------- *)
lemma dlist_dlist ['a] (d : 'a distr) (m n : int) :
0 <= m => 0 <= n =>
dmap (dlist (dlist d m) n) flatten = dlist d (m * n).
proof.
move=> ge0_m; elim: n => /= [|n ge0_n ih].
- by rewrite !dlist0 // dmap_dunit.
rewrite mulrDr /= [dlist d (m * n + m)]dlist_add //.
- by apply: IntOrder.mulr_ge0.
rewrite dlistSr //= dmap_comp !dmap_dprodE /=.
rewrite -ih dlet_dmap /= &(eq_dlet) // => xss /=.
by rewrite &(eq_dmap) => xs /=; rewrite flatten_rcons.
qed.
(* -------------------------------------------------------------------- *)
lemma dmatrix_dlist (r c : int) (d : R distr) :
0 <= r => 0 <= c => dmatrix d r c =
dmap
(dlist d (r * c))
(fun vs => offunm ((fun i j => vs.[j * r + i]), r, c)).
proof.
move=> ge0_r ge0_c @/dmatrix @/dvector.
rewrite dlist_dmap dmap_comp !lez_maxr //.
rewrite -dlist_dlist // dmap_comp &(eq_dmap_in) => xss /=.
case/(supp_dlist _ _ _ ge0_c) => size_xss /allP xssE.
have {xssE} xssE: forall xs, xs \in xss => size xs = r.
- by move=> xs /xssE /(supp_dlist _ _ _ ge0_r).
apply/eq_matrixP=> @/ofcols /= i j [].
rewrite !lez_maxr // => rgi rgj.
rewrite !get_offunm /= ?lez_maxr //.
rewrite (nth_map witness) 1:/#.
rewrite (get_oflist witness) 1:#smt:(mem_nth).
rewrite -nth_flatten ~-1:#smt:(mem_nth); do 2! congr.
rewrite sumzE BIA.big_map predT_comp /(\o) /=.
pose D := BIA.big predT (fun _ => r) (take j xss).
apply: (eq_trans _ D) => @/D.
- rewrite !BIA.big_seq &(BIA.eq_bigr) //=.
by move=> xs /mem_take /xssE.
by rewrite big_constz count_predT size_take //#.
qed.
(* -------------------------------------------------------------------- *)
lemma dmatrixD (r c : int) (d1 d2 : R distr) : 0 <= r => 0 <= c =>
dlet (dmatrix d1 r c) (fun (m1 : matrix) =>
dmap (dmatrix d2 r c) (fun (m2 : matrix) => m1 + m2))
= dmatrix (dadd d1 d2) r c.
proof.
move=> ge0_r ge0_c; rewrite 2?dmatrix_dlist //=.
pose F vs := offunm (fun i j => vs.[j * r + i], r, c).
rewrite dlet_dmap /= dlet_swap dlet_dmap /= dlet_swap /=.
rewrite dmatrix_dlist // -/F -dlistD ~-1:/#.
rewrite dmap_dlet &(eq_dlet) // => xs /=.
rewrite dlet_dunit dmap_comp &(eq_dmap) => ys /=.
apply/eq_matrixP; split.
- by rewrite /F size_addm !size_offunm.
move=> i j []; rewrite rows_addm cols_addm /=.
rewrite !rows_offunm !cols_offunm !maxzz => rgi rgj.
by rewrite get_addm !get_offunm //= nth_mkseq //#.
qed.
(* -------------------------------------------------------------------- *)
op dmul (n : int) (d1 d2 : R distr) =
dmap
(dlist d1 n `*` dlist d2 n)
(fun vs : R list * R list =>
DM.Big.BAdd.big predT
(fun xy : R * R => xy.`1 * xy.`2)
(zip vs.`1 vs.`2)).
(* -------------------------------------------------------------------- *)
lemma foo ['a 'b 'c] (da : 'a distr) (db : 'b distr) (f : 'a -> 'c) :
dmap (da `*` db) (fun ab : 'a * 'b => f ab.`1)
= weight db \cdot dmap da f.
proof. by rewrite dmap_dprodE_swap /= dlet_cst_weight. qed.

(* -------------------------------------------------------------------- *)
lemma dmatrix_cols (d : R distr) (r c : int) : 0 <= c => 0 <= r =>
dmatrix d r c = dmap (dlist (dvector d r) c) (ofcols r c).
proof. by move=> ge0_c ge0_r @/dmatrix; rewrite lez_maxr. qed.

(* -------------------------------------------------------------------- *)
lemma dmatrix_rows (d : R distr) (r c : int) : 0 <= c => 0 <= r =>
dmatrix d r c = dmap (dlist (dvector d c) r) (trmx \o ofcols c r).
proof.
move=> ge0_r ge0_c; rewrite -dmap_comp -dmatrix_cols //.
apply/eq_distr => /= m; rewrite (dmap1E _ trmx).
have ->: pred1 m \o trmx = pred1 (trmx m) by smt(trmxK).
case: (size m = (r, c)); last first.
- by move=> ne_size; rewrite !dmatrix0E //#.
case=> <<- <<-; rewrite -{2}rows_tr -{2}cols_tr !dmatrix1E /=.
by rewrite BRM.exchange_big.
qed.

(* -------------------------------------------------------------------- *)
hint simplify drop0, take0, cats0, cat0s.

(* -------------------------------------------------------------------- *)
(* FIXME: refactor *)

lemma dlist_insert ['a] (i n : int) (d : 'a distr) :
0 <= n => 0 <= i <= n => dlist d (n+1) =
dmap (d `*` dlist d n) (fun x_xs : 'a * 'a list => insert x_xs.`1 x_xs.`2 i).
proof.
move=> ge0_n [ge0_i lti]; apply/eq_sym.
pose f (x_xs : _ * _) := insert x_xs.`1 x_xs.`2 i.
pose g (xs : 'a list) := (xs.[i], take i xs ++ drop (i+1) xs).
have ge0_Sn: 0 <= n + 1 by smt(). apply: (dmap_bij _ _ f g).
- case=> [x xs] /supp_dprod[/=] x_in_d.
case/(supp_dlist _ _ _ ge0_n)=> sz_xs /allP xs_in_d.
move=> @/f /=; apply/supp_dlist; first smt().
rewrite size_insert ?sz_xs //=; apply/allP.
by move=> y /mem_insert[->>//|/xs_in_d].
- move=> xs /(supp_dlist _ _ _ ge0_Sn)[sz_xs /allP xs_in_d] @/g.
rewrite dprod1E !dlist1E ~-1://# sz_xs /=.
rewrite size_cat size_take // size_drop 1:/#.
rewrite iftrue 1:/# -(BRM.big_consT (mu1 d)) &(BRM.eq_big_perm).
by rewrite -cat_cons perm_eq_sym &(perm_eq_nth_take_drop) //#.
- case=> x xs /supp_dprod[/=] _ /(supp_dlist _ _ _ ge0_n)[sz_xs _].
rewrite /g /f /= nth_insert ?sz_xs //= take_insert_le 1:/#.
by rewrite drop_insert_gt 1:/# /= cat_take_drop.
- move=> xs /(supp_dlist _ _ _ ge0_Sn)[/=] sz_xs _ @/f @/g /=.
have sz_take: size (take i xs) = i by rewrite size_take //#.
by apply/insert_nth_take_drop => //#.
qed.

hint simplify insert0.

(* -------------------------------------------------------------------- *)
lemma dmatrix_cols_i (i : int) (d : R distr) (r c : int) :
0 <= c => 0 <= r => 0 <= i < c =>
dmatrix d r c =
dmap
(dvector d r `*` dlist (dvector d r) (c-1))
(fun c_cs : _ * _ => ofcols r c (insert c_cs.`1 c_cs.`2 i)).
proof.
move=> ge0_c ge0_r rgi; rewrite dmatrix_cols //.
rewrite {1}(_ : c = (c - 1) + 1) // (dlist_insert i) ~-1://# /=.
by rewrite dmap_comp &(eq_dmap) => -[v vs].
qed.

(* -------------------------------------------------------------------- *)
lemma dmatrix_rows_i (j : int) (d : R distr) (r c : int) :
0 <= c => 0 <= r => 0 <= j < r =>
dmatrix d r c =
dmap
(dvector d c `*` dlist (dvector d c) (r-1))
(fun r_rs : _ * _ => trmx (ofcols c r (insert r_rs.`1 r_rs.`2 j))).
proof.
move=> ge0_c ge0_r rgj; rewrite dmatrix_rows //.
rewrite {1}(_ : r = (r - 1) + 1) // (dlist_insert j) ~-1://# /=.
by rewrite dmap_comp &(eq_dmap) => -[v vs].
qed.

(* -------------------------------------------------------------------- *)
lemma col_ofcols (i r c : int) (vs : vector list) :
0 <= r => 0 <= c => 0 <= i < c
=> size vs = c
=> all (fun v : vector => size v = r) vs
=> col (ofcols r c vs) i = vs.[i].
proof.
move=> ge0_r ge0_c rgi sz_vs /allP => sz_in_vs.
have sz_rows: rows (ofcols r c vs) = r.
- by rewrite rows_offunm lez_maxr // sz_in_vs.
apply/eq_vectorP; split=> /=.
- by rewrite sz_rows sz_in_vs // &(mem_nth) sz_vs.
by move=> j; rewrite sz_rows => rgj; rewrite get_offunm //#.
qed.

(* -------------------------------------------------------------------- *)
lemma dmatrixM (m n p : int) (d1 d2 : R distr) :
0 <= m => 0 <= n => 0 <= p =>

let d =
dlet (dmatrix d1 m n) (fun (m1 : matrix) =>
dmap (dmatrix d2 n p) (fun (m2 : matrix) => m1 * m2)) in

forall i j, 0 <= i < m => 0 <= j < p =>
dmap d (fun m => m.[i, j]) =
((weight d1) ^ (n * (m-1)) * (weight d2) ^ (n * (p-1))) \cdot dmul n d1 d2.
proof.
move=> ge0_m ge0_n ge0_p d i j rg_i rg_j.
have [gt0_m gt0_p]: (0 <= m-1) /\ (0 <= p-1) by smt().
rewrite /d (dmatrix_rows_i i) //= (dmatrix_cols_i j) //=.
pose D1 := dvector d1 n `*` _; pose D2 := dvector d2 n `*` _.
pose F1 := fun (r_rs : _ * _) => trmx (ofcols n m (insert r_rs.`1 r_rs.`2 i)).
pose F2 := fun (c_cs : _ * _) => ofcols n p (insert c_cs.`1 c_cs.`2 j).
pose F r rs c cs := (trmx (ofcols n m (insert r rs i)) * ofcols n p (insert c cs j)).[i, j].
pose D := dlet D1 (fun c : _ * _ => dmap D2 (fun r : _ * _ => F c.`1 c.`2 r.`1 r.`2)).
apply: (eq_trans _ D) => @/D => {D}.
- rewrite dmap_dlet dlet_dmap /= &(eq_dlet) // => ? /=.
by rewrite 2!dmap_comp &(eq_dmap).
pose G (x_xs : (_ * _) * (_ * _)) := F x_xs.`1.`1 x_xs.`2.`1 x_xs.`1.`2 x_xs.`2.`2.
rewrite L2 /= => {D1 D2}; pose D1 := _ `*` _; pose D2 := _ `*` _.
have @/G /= <- := dmap_dprodE D1 D2 G => {G}.
pose G (vs : vector * vector) := dotp vs.`1 vs.`2.
apply: (eq_trans _ (dmap (D1 `*` D2) (fun x : _ * _ => G x.`1))).
- apply: eq_dmap_in=> -[[c r] [cs rs]] @/G @/F /=.
case/supp_dprod=> /= /supp_dprod[/=].
case/(supp_dvector _ _ _ ge0_n) => sz_c _.
case/(supp_dvector _ _ _ ge0_n) => sz_r _.
move/supp_dprod=> [/=].
case/(supp_dlist _ _ _ gt0_m) => [sz_cs /allP sz_in_cs].
case/(supp_dlist _ _ _ gt0_p) => [sz_rs /allP sz_in_rs].
rewrite get_mulmx row_trmx /= !col_ofcols //.
- by rewrite size_insert ?sz_cs //#.
- apply/allP=> v /mem_insert [->>|] //=.
by move/sz_in_cs => /(supp_dvector _ _ _ ge0_n).
- by rewrite size_insert ?sz_rs //#.
- apply/allP=> v /mem_insert [->>|] //=.
by move/sz_in_rs => /(supp_dvector _ _ _ ge0_n).
by rewrite !nth_insert // (sz_cs, sz_rs) //#.
rewrite foo /D2 weight_dprod !weight_dlist // !weight_dmap.
rewrite !weight_dlist ?lez_maxr // -!exprM.
congr=> @/D1 @/G => {D1 D2 G} @/dmul.
rewrite !dmap_dprodE /= dlet_dmap lez_maxr //.
apply/in_eq_dlet => //= xs /(supp_dlist _ _ _ ge0_n)[sz_xs _].
rewrite dmap_comp lez_maxr //; apply/eq_dmap_in => /= ys.
case/(supp_dlist _ _ _ ge0_n)=> sz_ys _ @/dotp.
rewrite !size_oflist sz_xs sz_ys lez_maxr //.
apply/eq_sym; rewrite (Big.BAdd.big_nth witness) predT_comp.
rewrite size_zip sz_xs sz_ys lez_minr //.
rewrite !Big.BAdd.big_seq /= &(Big.BAdd.eq_bigr) /=.
move=> k /mem_range rg_k; rewrite !(get_oflist witness) ~-1://#.
have := nth_zip witness witness xs ys k _; first by smt().
by rewrite (nth_change_dfl witness) => [|->//]; rewrite size_zip /#.
qed.

0 comments on commit 2754b01

Please sign in to comment.