-
Notifications
You must be signed in to change notification settings - Fork 0
/
category.ml
1218 lines (886 loc) · 30.7 KB
/
category.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
(**
This module introduces several categorical concepts and instances.
In this module, we translate the type classes and type class
instances of the following papers
"The Simple Essence of Automatic Differentiation" (ICFP'2018)
"Compiling to categories" (ICFP'2017)
by Conal Elliot
in OCaml using its module system.
*)
(**
What is the idea of the encoding?
We translate type classes as module types, type class instances
as module definitions and dictionaries as modules. Functors naturally
appear when an instance definition depends on another one.
There is some freedom while translating type classes as modules.
For instance, a type class inheritance relation like:
[class A => B where D]
can be translated by:
[module type B = sig
include A
D
end]
or by:
[module type B = sig
module A : A
D
end]
Since this development contains a long chain of concepts related
by inheritance, we had preferred (when possible) the first kind
of translation to reduce module nesting.
*)
(**
What is challenging?
The challenge is threefold:
1. The inference engine of Haskell automatically elaborates the
type class dictionaries the program needs given the methods it uses
; there is no such elaboration in OCaml. Therefore, we must
elaborate these type class dictionaries by hand before turning them
into modules. This is actually a good thing to help us understand
precisely what is going on.
2. Conal Eliott is using a notion of *constrained* type class. A
type class is constrained if it contains an associated constraint
which must be satisfied to apply any method of this type
class. (Typically, in the paper, this associated constraint help to
restrict the types of objects valid for a given category.) We
encode an associated constraint "Ok a" using a proof-term of the
form "'a ok" which must be explicitly provided as an argument to
the encoded method.
3. For the sake of simplicity, the paper does not always provide
the exact definitions that make the whole development work. For
instance, the aforementioned associated constraint is not included
in the initial definition of categories section 4.1 while it is
actually needed to introduce the category of additive functions.
Fortunately for you, we already solved this problem as the
important definitions are given here with all the necessary
details omitted in the paper. As a consequence, when the paper
and the code disagree, give more credit to the code.
*)
(**
What is your job here?
1. Read the following papers:
- "The Simple Essence of Automatic Differentiation" (ICFP'2018)
- "Compiling to categories" (ICFP'2017)
and optionally (but may be easier to start with):
- "Beautiful differentiation" (ICFP'2009)
2. Complete the code by replacing 'failwith "Student! This is your job!"'
or `todo` by the actual definition. While completing the code, we
suggest you run `make -C tests/task-1` to get some feedback about
your progress.
*)
(**
A category comprises objects and morphisms.
In this paper, the objects are the programming language types and
the morphisms are the inhabitants of [('a, 'b) k].
A category has a distinguised identity morphism named [id] as well
as a composition operator. [compose] is associative and admits [id]
as left and right neutral.
As said earlier, we actually constrain the objects to be the types
[a] that satisfy some constraint [C] over types. We encode this
constraint using a type constructor named [ok], i.e [C (a)] is
satisfied iff [a ok] is inhabited. We will usually assume that [ok]
is stable by type applications (as witnessed by [ok_pair] and
[ok_arrow] that will come in forthcoming definitions).
*)
module type Cat = sig
type ('a, 'b) k
type 'a ok
val id: 'a ok -> ('a, 'a) k
val compose: 'a ok -> 'b ok -> 'c ok ->
('b, 'c) k -> ('a, 'b) k -> ('a, 'c) k
end
(**
An initial (or universal or coterminal) object [unit] of a category
is such that for every object ['a] there is exactly one morphism
[it] from [t] to ['a].
*)
module type HasInitialObject = sig
type ('a, 'b) k
type 'a ok
val ti : 'a ok -> (unit, 'a) k
end
(**
A terminal (or final) object [unit] of a category is such that for
every object ['a] there is exactly one morphism [ti] from ['a] to
[t].
*)
module type HasTerminalObject = sig
type ('a, 'b) k
type 'a ok
val it : 'a ok -> ('a, unit) k
end
(**
A unit arrow starts from the terminal object.
*)
module type HasConstantArrows = sig
type ('a, 'b) k
type 'a ok
val unit_arrow : 'a ok -> 'a -> (unit, 'a) k
end
(**
*)
(**
A monoidal category allows us to form morphisms between products.
For the sake of simplicity and even though it restricts the
generality of the definition, we do not introduce a new type
constructor for products but we reuse instead the standard type
constructor [*].
Notice that we [include] the module type signature [Cat] in
[MonoidalCat] to encode the fact that a [MonoidalCat] is a
[Cat]. Notice also that we apply a destructive substitution of the
form (... with type ('a, 'b) k := ...) on the included signature
to share the type constructor for morphisms between module type
signatures.
*)
module type MonoidalCat = sig
type ('a, 'b) k
include Cat with type ('a, 'b) k := ('a, 'b) k
val ok_pair: 'a ok -> 'b ok -> ('a * 'b) ok
val pair : 'a ok -> 'b ok -> 'c ok -> 'd ok ->
('a, 'c) k -> ('b, 'd) k -> ('a * 'b, 'c * 'd) k
end
(**
A cartesian category is a monoidal category whose product is
categorical, that it is equipped with two projections [exl] and
[exr] and verifies the universal property for products.
A cartesian category has a terminal object. As before, we will
reuse the type [unit] to denote all terminal objects, even though
this restricts a bit the definition. We include in that definition
the existence of "constant arrows".
Finally, [dup] witnesses the fact that a cartesian category has
diagonal maps. This operator allows for data duplication.
*)
module type CartesianCat = sig
type ('a, 'b) k
include MonoidalCat with type ('a, 'b) k := ('a, 'b) k
val exl : 'a ok -> 'b ok -> ('a * 'b, 'a) k
val exr : 'a ok -> 'b ok -> ('a * 'b, 'b) k
val ok_unit : unit ok
include HasTerminalObject
with type ('a, 'b) k := ('a, 'b) k
with type 'a ok := 'a ok
include HasConstantArrows
with type ('a, 'b) k := ('a, 'b) k
with type 'a ok := 'a ok
val dup : 'a ok -> ('a, 'a * 'a) k
end
(**
A cartesian closed category is a cartesian category containing
exponential objects with the operations [apply], [curry] and
[uncurry].
Once again, we reuse the arrow of OCaml to represent the
exponential ['a -> 'b] of ['b] by ['a] even though the definition
could be made more general.
To handle constants, we require the presence of an initial object.
*)
module type CartesianClosedCat = sig
type ('a, 'b) k
include CartesianCat with type ('a, 'b) k := ('a, 'b) k
val ok_arrow: 'a ok -> 'b ok -> ('a -> 'b) ok
val apply : 'a ok -> 'b ok ->
(('a -> 'b) * 'a, 'b) k
val curry : 'a ok -> 'b ok -> 'c ok ->
('a * 'b, 'c) k -> ('a, 'b -> 'c) k
val uncurry : 'a ok -> 'b ok -> 'c ok ->
('a, 'b -> 'c) k -> ('a * 'b, 'c) k
end
(**
Simply typed lambda-calculus is a Cartesian Closed Category.
*)
module LambdaCat
: CartesianClosedCat
with type ('a, 'b) k = 'a -> 'b
with type 'a ok = unit
= struct
(** In STLC, the morphisms are the functions. *)
type ('a, 'b) k = 'a -> 'b
(** There is no restriction over the type of objects. *)
type 'a ok = unit
let ok_arrow () () = ()
let ok_pair () () = ()
(*-------------------------------------------------*)
(* Question 1: Complete the following definitions. *)
(* Then, run `make -C tests/task-1`. *)
(*-------------------------------------------------*)
(* DONE : Quit straightforward ! *)
let id () =
fun x -> x
let compose () () () f g =
fun x -> x |> g |> f
let pair () () () () f g =
fun (x,y) -> f x, g y
let exl () () (x, y) = x
let exr () () (x, y) = y
let dup () x = (x, x)
let apply () () (f, x) = f x
let curry () () () f =
fun x y -> f (x, y)
let uncurry () () () f =
fun (x, y) -> f x y
(** STLC has a terminal object. *)
let it () x = ()
(** STLC has constant arrows. *)
let unit_arrow () x =
fun () -> x
let ok_unit = ()
end
(** We will use composition in LambdaCat many times, hence we
introduce an infix operator to write composition more concisely. *)
let ( ** ) f g = LambdaCat.compose () () () f g
(**
A co-cartesian category is a monoidal category with
a left-injection [inl], a right-injection [inr] and
an operation [jam] which mirrors the duplication
of cartesian categories.
A co-cartesian category has an initial object.
*)
module type CoCartesianCat = sig
type ('a, 'b) k
include MonoidalCat with type ('a, 'b) k := ('a, 'b) k
val inl : 'a ok -> 'b ok -> ('a, 'a * 'b) k
val inr : 'a ok -> 'b ok -> ('b, 'a * 'b) k
val jam : 'a ok -> ('a * 'a, 'a) k
include HasInitialObject
with type ('a, 'b) k := ('a, 'b) k
with type 'a ok := 'a ok
end
(**
The following set of derived operations over Cartesian
and CoCartesian categories will be helpful.
*)
module CartesianCatDerivedOperations (C : CartesianCat) = struct
let fork oka okc okd f g =
C.(compose oka (ok_pair oka oka) (ok_pair okc okd)
(pair oka oka okc okd f g) (dup oka)
)
let unfork oka okc okd h =
C.((compose oka (ok_pair okc okd) okc (exl okc okd) h),
(compose oka (ok_pair okc okd) okd (exr okc okd) h))
end
module CoCartesianCatDerivedOperations (C : CoCartesianCat) = struct
let join oka okc okd (f, g) =
C.(compose (ok_pair okc okd) (ok_pair oka oka) oka
(jam oka) (pair okc okd oka oka f g))
let unjoin oka okc okd h =
C.((compose okc (ok_pair okc okd) oka h (inl okc okd)),
(compose okd (ok_pair okc okd) oka h (inr okc okd)))
end
(**
A type [t] is additive if there exists an addition over values of
type [t] and a zero for this addition.
*)
module type Additive = sig
type t
val zero : t
val add : t -> t -> t
end
(**
Unit is additive.
*)
module AdditiveUnit
: Additive with type t = unit
= struct
type t = unit
let zero = ()
let add () () = ()
end
(**
Float is additive.
*)
module AdditiveFloat
: Additive with type t = float
= struct
type t = float
let zero = 0.
let add x y = x +. y
end
(**
If [A] and [B] are additive, so is [A x B].
*)
module AdditivePair (AddA : Additive) (AddB : Additive)
: Additive with type t = AddA.t * AddB.t
= struct
type t = AddA.t * AddB.t
let zero = (AddA.zero, AddB.zero)
let add x y = (AddA.add (fst x) (fst y), AddB.add (snd x) (snd y))
end
(**
If [A] and [B] are additive then [A -> B] is additive too.
*)
module AdditiveLambda (A : sig type t end) (AddB : Additive)
: Additive with type t = A.t -> AddB.t
= struct
type t = A.t -> AddB.t
let zero = fun x -> AddB.zero
let add f g = fun x -> AddB.add (f x) (g x)
end
(**
A function [f] is additive if [f (x + y) = f x + f y].
*)
type ('a, 'b) additive_function =
AdditiveFun of ('a -> 'b)
(**
Additive types naturally provide injection and jam.
Notice the usage of first-class modules here: additive
structures are passed as arguments to these functions.
This flexibility is needed to define the next category.
*)
let inlF (type b) (module AddB : Additive with type t = b) =
fun x -> (x, AddB.zero)
let inrF (type a) (module AddA : Additive with type t = a) =
fun x -> (AddA.zero, x)
let jamF (type a) (module AddA : Additive with type t = a) =
fun (x, y) -> AddA.add x y
(**
Additive functions are a cartesian and co-cartesian category.
*)
module AdditiveFunctionsCat : sig
type ('a, 'b) k = ('a, 'b) additive_function
type 'a ok = (module Additive with type t = 'a)
include CartesianCat
with type ('a, 'b) k := ('a, 'b) k
with type 'a ok := 'a ok
include CoCartesianCat
with type ('a, 'b) k := ('a, 'b) k
with type 'a ok := 'a ok
include CartesianClosedCat
with type ('a, 'b) k := ('a, 'b) k
with type 'a ok := 'a ok
end
= struct
type ('a, 'b) k = ('a, 'b) additive_function
type 'a ok = (module Additive with type t = 'a)
(*-------------------------------------------------*)
(* Question 2: Complete the following definitions. *)
(* Then, run `make -C tests/task-1`. *)
(*-------------------------------------------------*)
let ok_pair (type a b) (oka : a ok) (okb : b ok) : (a * b) ok =
let module Oka = (val oka: Additive with type t = a) in
let module Okb = (val okb: Additive with type t = b) in
(module AdditivePair (Oka) (Okb): Additive with type t = a * b)
let id oka =
AdditiveFun (fun x -> x)
let compose oka okb okc (AdditiveFun f) (AdditiveFun g) =
AdditiveFun (fun x -> x |> g |> f)
let pair oka okb okc okd (AdditiveFun f) (AdditiveFun g) =
AdditiveFun (fun (x, y) -> f x, g y)
let exl oka okb =
AdditiveFun fst
let exr oka okb =
AdditiveFun snd
let dup oka =
AdditiveFun (fun x -> x,x)
(** [inl], [inr] & [jam] definitions use the [*F] functions defined above *)
let inl (type a b) (oka : a ok) (okb : b ok) =
AdditiveFun (inlF okb)
let inr (type a b) (oka : a ok) (okb : b ok) =
AdditiveFun (inrF oka)
let jam (type a) (oka : a ok) =
AdditiveFun (jamF oka)
let ti (type a) (module AddA : Additive with type t = a) =
AdditiveFun (fun () -> AddA.zero)
let it (type a) (module AddA : Additive with type t = a) =
AdditiveFun (fun _ -> ())
let unit_arrow (type a) (module AddA : Additive with type t = a) x =
AdditiveFun (fun () -> x)
let ok_unit : unit ok =
let module OkU = struct
type t = unit
let zero = ()
let add () () = ()
end in
(module OkU : Additive with type t = unit)
let apply oka okb =
AdditiveFun (fun (f, x) -> f x)
let curry oka okb okc (AdditiveFun f) =
AdditiveFun (fun x -> (fun y -> f (x, y)))
let uncurry oka okb okc (AdditiveFun f) =
AdditiveFun (fun (x, y) -> f x y)
let ok_arrow (type a b) oka okb =
let module Oka = (val oka: Additive with type t = a) in
let module Okb = (val okb: Additive with type t = b) in
(module AdditiveLambda (Oka) (Okb): Additive with type t = a -> b)
end
(**
Since we will differentiate numerical function, we need
a type for numeric values.
*)
module type Num = sig
type t
val neg : t -> t
val add : t -> t -> t
val zero : t
val mul : t -> t -> t
end
(**
Float is a famous example of such a type.
*)
module NumFloat : Num with type t = float = struct
type t = float
let neg x = -.x
let zero = 0.
let add x y = x +. y
let mul x y = x *. y
end
(**
A category can be equipped with morphisms that denote numerical
operations.
*)
module type NumCat = sig
type ('a, 'b) k
type t
val negC : (t, t) k
val addC : (t * t, t) k
val mulC : (t * t, t) k
end
(**
We also need some primitive numerical functions.
*)
module type Floating = sig
type t
val sin : t -> t
val cos : t -> t
val exp : t -> t
val inv : t -> t
end
(**
These primitive functions are builtin in OCaml.
*)
module FloatingFloat : Floating with type t = float
= struct
type t = float
let sin = Pervasives.sin
let cos = Pervasives.cos
let exp = Pervasives.exp
let inv x = 1. /. x
end
(**
Again, we transfer these primitive functions in
the categorical world.
*)
module type FloatingCat = sig
type ('a, 'b) k
type t
val sinC : (t, t) k
val cosC : (t, t) k
val expC : (t, t) k
val invC : (t, t) k
end
(**
Given a numerical type, LambdaCat can be extended with the
morphisms for numerical operations.
*)
module LambdaNumCatFromNum (Num : Num)
: NumCat
with type ('a, 'b) k = 'a -> 'b
with type t = Num.t
= struct
type ('a, 'b) k = 'a -> 'b
type t = Num.t
let negC = Num.neg
let addC = LambdaCat.uncurry () () () Num.add
let mulC = LambdaCat.uncurry () () () Num.mul
end
(**
Idem for additive functions.
*)
module AdditiveFunctionNumCatFromNum (Num : Num)
: NumCat
with type ('a, 'b) k = ('a, 'b) additive_function
with type t = Num.t
= struct
type ('a, 'b) k = ('a, 'b) additive_function
type t = Num.t
let negC = AdditiveFun Num.neg
let addC = AdditiveFun (LambdaCat.uncurry () () () Num.add)
let mulC = AdditiveFun (LambdaCat.uncurry () () () Num.mul)
end
module FloatLambdaCat = struct
include LambdaCat
let ok_float = ()
include (LambdaNumCatFromNum (NumFloat) :
NumCat
with type ('a, 'b) k := ('a, 'b) k
with type t = float)
let cosC = FloatingFloat.cos
let sinC = FloatingFloat.sin
let expC = FloatingFloat.exp
let invC = FloatingFloat.inv
end
(**
A final notion: a category is scalable if for each
scalar, there is a scaling morphism.
*)
module type Scalable = sig
type ('a, 'b) k
type t
val scale : t -> (t, t) k
end
(**
The category of additive functions is scalable.
*)
module ScalableAdditiveFromNum (Num : Num)
: Scalable
with type t = Num.t
with type ('a, 'b) k = ('a, 'b) additive_function
= struct
type ('a, 'b) k = ('a, 'b) additive_function
type t = Num.t
let scale x = AdditiveFun (fun dx -> Num.mul dx x)
end
(**
The "good" categories for automatic differentiation are categories
which are both cartesian and cocartesian. Let us introduce a
module signature for them.
*)
module type GoodCat = sig
include CartesianCat
include CoCartesianCat
with type ('a, 'b) k := ('a, 'b) k
with type 'a ok := 'a ok
end
(**
From a cartesian and cocartesian category [C], we construct a
cartesian and cocartesian category where morphisms are
differentiable functions.
Given an input [x], a differentiable function [f] not only produces
an output [y] but also a derivative [df] of [f] which is represented
by a morphism of C.
As explained in the paper, a derivative of f is a linear map f'
such that:
{v
| f (x + e) - (f x + f' x e) |
lim ------------------------------ = 0
e -> o | e |
v}
This definition is general enough to capture a notion of derivatives
over higher-dimensional spaces, typically Rⁿ → Rᵐ.
*)
module DiffCat (C : GoodCat) :
sig
type ('a, 'b) k = D of ('a -> ('b * ('a, 'b) C.k))
type 'a ok = (module Additive with type t = 'a) * 'a C.ok
val linearD : ('a -> 'b) -> ('a, 'b) C.k -> ('a, 'b) k
include GoodCat
with type ('a, 'b) k := ('a, 'b) k
with type 'a ok := 'a ok
end
= struct
type ('a, 'b) k = D of ('a -> ('b * ('a, 'b) C.k))
(** [linearD f f'] cnstruit un morphisme de la categorie [DiffCat] en
utilisant :
- [f] Comme fonction qui produit l'output [y]
- [f'] Comme differentielle
*)
let linearD f f' = D (fun x -> (f x, f'))
(** Pour cette categorie la contrainte sur les elements est l'intersection
- de la contrainte sur les objets additifs
- et de la contrainte sur les elements de la categorie [C]
*)
type 'a ok = (module Additive with type t = 'a) * 'a C.ok
(*-------------------------------------------------*)
(* Question 3: Complete the following definitions. *)
(* Then, run `make -C tests/task-1`. *)
(*-------------------------------------------------*)
let ok_pair (type a b) (oka : a ok) (okb : b ok) : (a * b) ok =
let adda, okaC = oka in
let addb, okbC = okb in
let module AddA = (val adda: Additive with type t = a) in
let module AddB = (val addb: Additive with type t = b) in
let module AddAB = AdditivePair (AddA) (AddB) in
let addab = (module AddAB: Additive with type t = (a * b)) in
(addab, C.ok_pair okaC okbC)
let id oka =
(* d'abord on extrait le morphisme de differentielle *)
let df = C.id @@ snd @@ oka in
D (fun x -> x, df)
let compose oka okb okc (D g) (D f) =
D (fun x ->
let y, df_x = f x in
let z, dg_fx = g y in
let dfg = C.compose (snd oka) (snd okb) (snd okc) dg_fx df_x in
z, dfg)
let pair
: type a b c d.
a ok -> b ok -> c ok -> d ok ->
(a, c) k -> (b, d) k -> (a * b, c * d) k
= fun oka okb okc okd (D f) (D g) ->
D (fun (x, y) ->
let z, df_x = f x in
let t, dg_y = g y in
let dfxg = C.pair
(snd oka) (snd okb) (snd okc) (snd okd)
df_x dg_y in
(z, t), dfxg)
let exl oka okb =
D (fun (x, y) ->
let df = C.exl (snd oka) (snd okb) in
x, df)
let dup oka =
D (fun x ->
let df = C.dup (snd oka) in
(x, x), df)
let exr oka okb =
D (fun (x, y) ->
let df = C.exr (snd oka) (snd okb) in
y, df)
let inl oka okb =
let f = inlF (fst okb) in
let df = C.inl (snd oka) (snd okb) in
linearD f df
let inr oka okb =
let f = inrF (fst oka) in
let df = C.inr (snd oka) (snd okb) in
linearD f df
let jam oka =
let f = jamF (fst oka) in
let df = C.jam (snd oka) in
linearD f df
let ti (type a) (((module AddA), coka) : a ok) : (unit, a) k =
let f () = AddA.zero in
let df = C.ti coka in
linearD f df
let it (type a) (oka : a ok) : (a, unit) k =
let f _ = () in
let df = C.it (snd oka) in
linearD f df
let unit_arrow (type a) (oka : a ok) (x: a) : (unit, a) k =
let module Add = (val fst oka: Additive with type t = a) in
let f () = x in
let df = C.unit_arrow (snd oka) Add.zero in
linearD f df
let ok_unit =
((module AdditiveUnit : Additive with type t = unit), C.ok_unit)
end
(**
Automatic differentiation is usually applied to numerical
functions. In that case, the category is typically extended with
numerical primitives and a type [Num.t] for scalars.
*)
module type GoodScalableFloatingCat = sig
include GoodCat
module Num : Num
include NumCat
with type ('a, 'b) k := ('a, 'b) k
with type t = Num.t
include Scalable
with type ('a, 'b) k := ('a, 'b) k
with type t := Num.t
val ok_t : t ok
end
(**
Additive functions over floats are an example of such category.
*)
module GoodScalableFloatingAdditiveFunctionsCat
: GoodScalableFloatingCat
with type ('a, 'b) k = ('a, 'b) AdditiveFunctionsCat.k
with module Num = NumFloat
with type 'a ok = 'a AdditiveFunctionsCat.ok
with type t = NumFloat.t
= struct
include AdditiveFunctionsCat
module Num = NumFloat
include (ScalableAdditiveFromNum(Num)
: Scalable
with type ('a, 'b) k := ('a, 'b) k
with type t = NumFloat.t)
include (AdditiveFunctionNumCatFromNum(Num)
: NumCat
with type ('a, 'b) k := ('a, 'b) k
with type t := NumFloat.t)
let ok_t : t ok = (module Num)
let ok_unit : unit ok = (module AdditiveUnit)
end
(**
Automatic differentiation for functions from float to float.
*)
module DiffNumCat
(C : GoodScalableFloatingCat)
(Floating : Floating with type t = C.t)
: sig
type ('a, 'b) k = D of ('a -> ('b * ('a, 'b) C.k))
type 'a ok = (module Additive with type t = 'a) * 'a C.ok
val linearD : ('a -> 'b) -> ('a, 'b) C.k -> ('a, 'b) k
include GoodCat
with type ('a, 'b) k := ('a, 'b) k
with type 'a ok := 'a ok
include NumCat
with type ('a, 'b) k := ('a, 'b) k
with type t := C.t
include FloatingCat
with type ('a, 'b) k := ('a, 'b) k
with type t := C.t
end
= struct
include DiffCat (C)
(*-------------------------------------------------*)
(* Question 4: Complete the following definitions. *)
(* Then, run `make -C tests/task-1`. *)
(*-------------------------------------------------*)
(* TODO *)
(** Since [negC], [addC] are linear functions, on every points, the