1 (**************************************************************************)
2 (* *)
3 (* OCaml *)
4 (* *)
5 (* Pierre Chambart, OCamlPro *)
6 (* Mark Shinwell and Leo White, Jane Street Europe *)
7 (* *)
8 (* Copyright 2013--2016 OCamlPro SAS *)
9 (* Copyright 2014--2016 Jane Street Group LLC *)
10 (* *)
11 (* All rights reserved. This file is distributed under the terms of *)
12 (* the GNU Lesser General Public License version 2.1, with the *)
13 (* special exception on linking described in the file LICENSE. *)
14 (* *)
15 (**************************************************************************)
16
(** A type [t] equipped with equality and hashing (from
    [Hashtbl.HashedType]), a total order (from [Map.OrderedType]), and two
    printers: one to an [out_channel] and one to a [Format] formatter. *)
module type Thing = sig
  type t

  include Hashtbl.HashedType with type t := t
  include Map.OrderedType with type t := t

  (** [output oc x] writes [x] to the channel [oc]. *)
  val output : out_channel -> t -> unit

  (** [print ppf x] prints [x] on the formatter [ppf]. *)
  val print : Format.formatter -> t -> unit
end
26
(** Sets over [T.t]. The set type is exactly [Set.Make (T).t], so values are
    interchangeable with a plain stdlib set over [T]. The signature adds
    printing and conversion helpers on top of [Set.S]. *)
module type Set = sig
  module T : Set.OrderedType
  include Set.S
    with type elt = T.t
     and type t = Set.Make (T).t

  (** Writes the set to an output channel. *)
  val output : out_channel -> t -> unit

  (** Prints the set on a [Format] formatter. *)
  val print : Format.formatter -> t -> unit

  (** Renders the set as a string (as produced by [print]). *)
  val to_string : t -> string

  (** Builds a set from a list; duplicate elements are collapsed. *)
  val of_list : elt list -> t

  (** [map f s] is the set of images [f x] for [x] in [s]. *)
  val map : (elt -> elt) -> t -> t
end
39
(** Maps keyed by [T.t]. The map type is exactly [Map.Make (T).t]. The
    signature extends [Map.S] with union variants, key/data conversions, and
    a printer. *)
module type Map = sig
  module T : Map.OrderedType
  include Map.S
    with type key = T.t
     and type 'a t = 'a Map.Make (T).t

  (** [filter_map t ~f] keeps the bindings for which [f] returns [Some r],
      rebinding each such key to [r]. *)
  val filter_map : 'a t -> f:(key -> 'a -> 'b option) -> 'b t

  (** Builds a map from an association list; a later pair for the same key
      replaces an earlier one. *)
  val of_list : (key * 'a) list -> 'a t

  (** Union of two maps whose domains are expected to be disjoint. A key
      bound in both maps is tolerated only if [eq] is supplied and returns
      [true] on the two values (the value from the first map is kept).
      Otherwise it is a fatal error; [print], when supplied, is used to
      include the clashing values in the error message. *)
  val disjoint_union :
    ?eq:('a -> 'a -> bool) -> ?print:(Format.formatter -> 'a -> unit) -> 'a t ->
    'a t -> 'a t

  (** Union where, on keys bound in both maps, the second map's value wins. *)
  val union_right : 'a t -> 'a t -> 'a t

  (** Union where, on keys bound in both maps, the first map's value wins. *)
  val union_left : 'a t -> 'a t -> 'a t

  (** [union_merge f m1 m2] is the union of [m1] and [m2], combining the two
      values of a key bound in both with [f]. *)
  val union_merge : ('a -> 'a -> 'a) -> 'a t -> 'a t -> 'a t

  (** [rename m k] is the key bound to [k] in [m], or [k] itself when [k] is
      unbound. *)
  val rename : key t -> key -> key

  (** Applies a function to every key. If several keys are sent to the same
      new key, the binding whose original key is greatest wins. *)
  val map_keys : (key -> key) -> 'a t -> 'a t

  (** The set of keys bound in the map. *)
  val keys : 'a t -> Set.Make(T).t

  (** The bound values, in increasing order of their keys. *)
  val data : 'a t -> 'a list

  (** [of_set f s] binds every element [k] of [s] to [f k]. *)
  val of_set : (key -> 'a) -> Set.Make(T).t -> 'a t

  (** Swaps keys and values. If several keys share a value, the greatest such
      key is kept. *)
  val transpose_keys_and_data : key t -> key t

  (** Maps every value to the set of all keys that were bound to it. *)
  val transpose_keys_and_data_set : key t -> Set.Make(T).t t

  (** [print pp_value ppf m] prints [m], using [pp_value] for the values. *)
  val print :
    (Format.formatter -> 'a -> unit) -> Format.formatter -> 'a t -> unit
end
68
(** Imperative hash tables keyed by [T.t]. The table type is exactly
    [Hashtbl.Make (T).t]. [T] must be both ordered and hashable so tables can
    be converted to and from [Map.Make (T)] maps. *)
module type Tbl = sig
  module T : sig
    type t
    include Map.OrderedType with type t := t
    include Hashtbl.HashedType with type t := t
  end
  include Hashtbl.S
    with type key = T.t
     and type 'a t = 'a Hashtbl.Make (T).t

  (** Lists the bindings of the table, in an unspecified order. *)
  val to_list : 'a t -> (T.t * 'a) list

  (** Builds a fresh table from an association list. *)
  val of_list : (T.t * 'a) list -> 'a t

  (** Converts the table to a [Map.Make (T)] map. *)
  val to_map : 'a t -> 'a Map.Make(T).t

  (** Builds a fresh table holding the bindings of the map. *)
  val of_map : 'a Map.Make(T).t -> 'a t

  (** [memoize t f] is a function that returns the value cached for a key in
      [t], or computes it with [f] and caches it in [t] on a miss. *)
  val memoize : 'a t -> (key -> 'a) -> key -> 'a

  (** [map t f] is a new table whose values are the images under [f] of the
      values of [t]. *)
  val map : 'a t -> ('a -> 'b) -> 'b t
end
87
(* The product of two [Thing]s, ordered lexicographically (first component
   first). Equality, hashing and printing are defined componentwise. *)
module Pair (A : Thing) (B : Thing) : Thing with type t = A.t * B.t = struct
  type t = A.t * B.t

  (* The second components only matter when the first ones tie. *)
  let compare (a1, b1) (a2, b2) =
    match A.compare a1 a2 with
    | 0 -> B.compare b1 b2
    | c -> c

  let equal (a1, b1) (a2, b2) = A.equal a1 a2 && B.equal b1 b2

  (* Mix the two component hashes through the generic hash function. *)
  let hash (a, b) = Hashtbl.hash (A.hash a, B.hash b)

  let output oc (a, b) =
    Printf.fprintf oc " (%a, %a)" A.output a B.output b

  let print ppf (a, b) =
    Format.fprintf ppf " (%a, @ %a)" A.print a B.print b
end
101
(* Stdlib maps keyed by [T.t], extended with union variants, key/data
   conversions and a printer. *)
module Make_map (T : Thing) = struct
  include Map.Make (T)

  (* Sets over the same key type, used by [keys], [of_set] and
     [transpose_keys_and_data_set]. *)
  module T_set = Set.Make (T)

  (* Rebuild the map, keeping only bindings for which [f] yields [Some]. *)
  let filter_map t ~f =
    let keep key value acc =
      match f key value with
      | Some result -> add key result acc
      | None -> acc
    in
    fold keep t empty

  (* Later pairs for the same key overwrite earlier ones. *)
  let of_list l =
    let insert acc (key, value) = add key value acc in
    List.fold_left insert empty l

  (* A key present in both maps is a fatal error unless [eq] says the two
     values are equal, in which case the value from [m1] is kept. *)
  let disjoint_union ?eq ?print m1 m2 =
    let on_clash id v1 v2 =
      let tolerated =
        match eq with
        | Some eq -> eq v1 v2
        | None -> false
      in
      if tolerated then Some v1
      else begin
        let msg =
          match print with
          | Some print ->
            Format.asprintf "Map.disjoint_union %a => %a <> %a"
              T.print id print v1 print v2
          | None ->
            Format.asprintf "Map.disjoint_union %a" T.print id
        in
        Misc.fatal_error msg
      end
    in
    union on_clash m1 m2

  (* Prefer the right-hand binding whenever there is one. *)
  let union_right m1 m2 =
    let pick _id left right =
      match right with
      | Some _ -> right
      | None -> left
    in
    merge pick m1 m2

  let union_left m1 m2 = union_right m2 m1

  (* Keys bound on one side only keep their value; keys bound on both sides
     get [f] applied to the two values. *)
  let union_merge f m1 m2 =
    merge
      (fun _key left right ->
        match left, right with
        | Some x, Some y -> Some (f x y)
        | None, other | other, None -> other)
      m1 m2

  (* Unbound keys are left unchanged. *)
  let rename m v =
    match find v m with
    | renamed -> renamed
    | exception Not_found -> v

  let map_keys f m =
    bindings m
    |> List.map (fun (key, value) -> (f key, value))
    |> of_list

  let print f ppf s =
    let pp_bindings ppf map =
      iter
        (fun id v -> Format.fprintf ppf "@ (@[%a@ %a@])" T.print id f v)
        map
    in
    Format.fprintf ppf "@[<1>{@[%a@ @]}@]" pp_bindings s

  let keys map =
    fold (fun key _ acc -> T_set.add key acc) map T_set.empty

  let data t = List.map (fun (_, value) -> value) (bindings t)

  let of_set f set =
    T_set.fold (fun elt acc -> add elt (f elt) acc) set empty

  let transpose_keys_and_data map =
    fold (fun key data acc -> add data key acc) map empty

  (* Group the keys by the value they are bound to. *)
  let transpose_keys_and_data_set map =
    let gather key data acc =
      let keys_for_data =
        match find data acc with
        | existing -> T_set.add key existing
        | exception Not_found -> T_set.singleton key
      in
      add data keys_for_data acc
    in
    fold gather map empty
end
184
(* Stdlib sets over [T.t], extended with channel/formatter printers and
   list-based construction helpers. *)
module Make_set (T : Thing) = struct
  include Set.Make (T)

  (* Emits " ( ", then each element (in increasing order) followed by a
     space, then ")". *)
  let output oc s =
    Printf.fprintf oc " ( ";
    elements s |> List.iter (fun elt -> Printf.fprintf oc "%a " T.output elt);
    Printf.fprintf oc ")"

  (* Elements are each preceded by a break hint inside a "{ ... }" box. *)
  let print ppf s =
    let pp_elements ppf set =
      set |> iter (fun elt -> Format.fprintf ppf "@ %a" T.print elt)
    in
    Format.fprintf ppf "@[<1>{@[%a@ @]}@]" pp_elements s

  let to_string set = Format.asprintf "%a" print set

  let of_list = function
    | [] -> empty
    | first :: rest ->
      List.fold_left (fun acc elt -> add elt acc) (singleton first) rest

  (* Images that coincide are collapsed into a single element. *)
  let map f s = elements s |> List.map f |> of_list
end
206
(* Stdlib hash tables keyed by [T.t], extended with list/map conversions,
   memoization and a value-mapping helper. *)
module Make_tbl (T : Thing) = struct
  include Hashtbl.Make (T)

  module T_map = Make_map (T)

  (* Every binding of [t], including any shadowed by a later [add] of the
     same key. The list order is unspecified (it follows [fold]). *)
  let to_list t =
    fold (fun key datum elts -> (key, datum) :: elts) t []

  (* Uses [add], so when a key occurs several times in [elts] the later
     occurrence is the one returned by [find]. *)
  let of_list elts =
    let t = create 42 in
    List.iter (fun (key, datum) -> add t key datum) elts;
    t

  (* [Hashtbl.fold] presents multiple bindings of one key most recent first,
     so keep the first one seen for each key: the resulting map then agrees
     with [find]. Adding unconditionally would let older, shadowed bindings
     overwrite the visible one. *)
  let to_map v =
    fold
      (fun key datum map ->
        if T_map.mem key map then map else T_map.add key datum map)
      v T_map.empty

  let of_map m =
    let t = create (T_map.cardinal m) in
    T_map.iter (fun k v -> add t k v) m;
    t

  (* On a miss, [f key] is computed and cached in [t]. Exceptions raised by
     [f] propagate and nothing is cached. *)
  let memoize t f = fun key ->
    match find t key with
    | cached -> cached
    | exception Not_found ->
      let r = f key in
      add t key r;
      r

  (* Goes through [to_map], so only the visible binding of each key is kept
     (shadowed bindings are dropped). *)
  let map t f =
    of_map (T_map.map f (to_map t))
end
237
(** A [Thing] packaged with its containers: the operations on [t] are
    available both through the submodule [T] and flattened at top level.
    [Set], [Map] and [Tbl] are specialised to [T], with their own [T]
    submodule substituted away. *)
module type S = sig
  type t

  module T : Thing with type t = t
  include Thing with type t := T.t

  module Set : Set with module T := T
  module Map : Map with module T := T
  module Tbl : Tbl with module T := T
end
248
(* Bundles a [Thing] with its set, map and hash-table modules. *)
module Make (T : Thing) = struct
  (* The [Thing] operations, flattened and also kept as a submodule. *)
  include T
  module T = T

  (* Containers specialised to [T]; the three applications are independent. *)
  module Tbl = Make_tbl (T)
  module Map = Make_map (T)
  module Set = Make_set (T)
end
257