(* A Huffman tree is either a leaf, which carries a character, or a binary
   node, which carries two sub-trees. Here, we choose NOT to store frequency
   information inside every leaf and node. Instead, we accompany a tree with a
   frequency when we insert it into the priority queue. *)

type tree =
(* A leaf holds one character. *)
| Leaf of char
(* An internal node holds two sub-trees. [find] descends into the first
   one on '0' and into the second one on '1'. *)
| Node of tree * tree

(* We represent a path in a tree as a string that contains only '0' and '1'
   characters. We choose this representation for the sake of simplicity: of
   course, a packed array of booleans would be more compact. *)

(* Every character of a path is expected to be '0' or '1'. [find] asserts
   this for each character that it reads. *)
type path =
  string

(* [find path i tree] walks down [tree], reading one character of [path]
   per node crossed, starting at index [i]. The character '0' selects the
   first sub-tree of a node and '1' selects the second. The result is the
   character carried by the leaf that is reached, paired with the index of
   the first character of [path] that has not been read. This is used
   while decoding.

   Assertions fail if [i] is out of range, if [path] ends before a leaf is
   reached, or if a character other than '0' or '1' is read. *)

let rec find (path : path) (i : int) (tree : tree) : char * int =
  let n = String.length path in
  (* [i = n] is legal here: it is fine as long as we are already at a leaf. *)
  assert (0 <= i && i <= n);
  match tree with
  | Node (tree0, tree1) ->
      (* Crossing a node consumes one character, so one must be available. *)
      assert (i < n);
      let bit = path.[i] in
      (* [path] must contain only binary digits. *)
      assert (bit = '0' || bit = '1');
      let subtree = if bit = '0' then tree0 else tree1 in
      find path (i + 1) subtree
  | Leaf c ->
      (* Arrived: nothing more is consumed. *)
      c, i

(* We use priority queues whose elements are pairs of a tree and a
   frequency. The queue implementation comes from [BinomialQueue.Make],
   which is defined outside this file. *)

module Q =
  BinomialQueue.Make(struct
    (* The priority queue contains pairs of a tree and a frequency. *)
    type t =
      tree * int
    (* Elements are compared based on their frequency only; the tree is
       ignored. In other words, drawing an element out of the queue is
       meant to yield an element with least frequency.

       We do not compute [freq1 - freq2]: that subtraction can overflow
       when the two frequencies are far apart (e.g. one of them negative
       and the other very large), which would yield a result of the wrong
       sign. Because this [let] is not recursive, the [compare] on the
       right-hand side is the standard comparison function; the type
       annotations specialise it to integers. *)
    let compare (_, freq1) (_, freq2) =
      compare (freq1 : int) (freq2 : int)
  end)

(* An alphabet maps characters to integer frequencies. [build_alphabet]
   produces one by counting the occurrences of each character in a text,
   and [build_tree] consumes one. *)

type alphabet =
  (char, int) Hashtbl.t

(* [build_tree alphabet] builds a Huffman tree out of [alphabet].

   Every binding of [alphabet] is first inserted into a priority queue as
   a one-leaf tree paired with its frequency. Then, as long as the queue
   holds more than one element, the two elements that [Q.extract] yields
   first are removed, joined under a fresh [Node] whose frequency is the
   sum of theirs, and the result is inserted back. The last remaining
   tree is returned.

   An assertion fails if [alphabet] has fewer than two bindings. *)

let build_tree (alphabet : alphabet) : tree =
  (* Precondition: at least two symbols. *)
  assert (Hashtbl.length alphabet >= 2);
  (* Seed the queue with one leaf per symbol. *)
  let add_leaf symbol freq acc = Q.insert (Leaf symbol, freq) acc in
  let initial : Q.t = Hashtbl.fold add_leaf alphabet Q.empty in
  (* Merge pairs of trees until a single one is left. *)
  let rec merge (pending : Q.t) : tree =
    (* Invariant: [pending] is never empty on entry. *)
    assert (not (Q.is_empty pending));
    let (first, weight0), rest = Q.extract pending in
    if Q.is_empty rest then
      (* [first] was the only tree left: it is the result. *)
      first
    else
      (* Take a second tree out and combine it with the first one; the
         first tree extracted becomes the '0' branch. *)
      let (second, weight1), rest = Q.extract rest in
      merge (Q.insert (Node (first, second), weight0 + weight1) rest)
  in
  merge initial

(* By traversing a tree, one can build a mapping of characters to
   their encodings, which are strings of bits. As we go down, we keep
   track of the path that we have followed into the tree. *)

(* A cipher text is an encoded text: a string of '0' and '1' characters.
   [encode] produces it by concatenating the codes of the characters. *)
type cipher_text =
  string

(* An encoding dictionary maps a character to its code, that is, to the
   path that leads from the root of the tree to the leaf that carries this
   character. *)
type encoding_dictionary =
  (char, cipher_text) Hashtbl.t

(* [build_dictionary tree] returns a fresh table that binds the character
   of every leaf of [tree] to the path from the root to that leaf. Going
   into the first sub-tree of a node appends '0' to the path; going into
   the second appends '1'. *)
let build_dictionary (tree : tree) : encoding_dictionary =
  let codes : encoding_dictionary = Hashtbl.create 256 in
  (* [prefix] is the path followed from the root down to the current
     sub-tree. *)
  let rec walk (prefix : string) = function
    | Node (left, right) ->
        walk (prefix ^ "0") left;
        walk (prefix ^ "1") right
    | Leaf symbol ->
        Hashtbl.add codes symbol prefix
  in
  walk "" tree;
  codes

(* Encoding. *)

(* [encode_char dictionary c] looks up the code of [c] in [dictionary].
   A character that has no binding is treated as a programming error:
   [assert false] is executed. *)
let encode_char (dictionary : encoding_dictionary) (c : char) : cipher_text =
  match Hashtbl.find dictionary c with
  | code ->
      code
  | exception Not_found ->
      assert false (* unknown character *)

(* A plain text is an ordinary string: the input of [encode] and the
   output of [decode]. *)
type plain_text =
  string

(* [encode dictionary text] concatenates, in order, the codes that
   [encode_char] finds for the characters of [text]. It fails in the same
   way as [encode_char] if some character of [text] has no code. *)
let encode (dictionary : encoding_dictionary) (text : plain_text) : cipher_text =
  let out = Buffer.create 1024 in
  (* Append the code of one character to the output. *)
  let emit c = Buffer.add_string out (encode_char dictionary c) in
  String.iter emit text;
  Buffer.contents out

(* Decoding. *)

(* [decode tree text] decodes [text] by calling [find] from the root of
   [tree] again and again, each call yielding one character and the index
   where the next code begins, until the end of [text] is reached.

   Raises [Invalid_argument] if [tree] is a single leaf and [text] is not
   empty: in that case [find] consumes nothing, so without this check the
   loop would never terminate. Other malformed inputs (a text that stops
   in the middle of a code, or contains a character other than '0' or
   '1') make the assertions inside [find] fail. *)
let decode (tree : tree) (text : cipher_text) : plain_text =
  let length = String.length text in
  let buffer = Buffer.create 1024 in
  let rec loop i =
    if i = length then
      (* We have reached the end of the text. We are done. *)
      Buffer.contents buffer
    else begin
      (* Decode one more character. *)
      let c, next = find text i tree in
      (* [find] reads one character per node that it crosses, so it makes
         progress unless the root itself is a leaf. Refuse to spin. *)
      if next <= i then
        invalid_arg "decode: the tree is a single leaf, so a nonempty text cannot be decoded";
      Buffer.add_char buffer c;
      loop next
    end
  in
  loop 0

(* Pick new names for the end user. *)

(* The decoding dictionary is the Huffman tree itself: [decode] follows
   paths in it. *)
type decoding_dictionary =
  tree

(* [build_dictionaries alphabet] builds the Huffman tree for [alphabet]
   and returns the pair of the encoding dictionary derived from that tree
   and the tree itself, which serves as the decoding dictionary. The
   precondition of [build_tree] applies. *)
let build_dictionaries alphabet =
  let decoder = build_tree alphabet in
  let encoder = build_dictionary decoder in
  encoder, decoder

(* Dumping a tree. *)

(* [write tree] serialises [tree] in prefix order: a leaf is written as
   'L' followed by its character, and a node is written as 'N' followed by
   its first sub-tree, then its second sub-tree. *)
let write (tree : tree) : string =
  let out = Buffer.create 1024 in
  let rec emit = function
    | Node (left, right) ->
        Buffer.add_char out 'N';
        emit left;
        emit right
    | Leaf symbol ->
        Buffer.add_char out 'L';
        Buffer.add_char out symbol
  in
  emit tree;
  Buffer.contents out

(* Reading back a tree. *)

(* [read s] parses a tree out of [s], which is expected to be in the
   format produced by [write]: 'L' followed by one character denotes a
   leaf, and 'N' followed by two trees denotes a node.

   Assertions fail if [s] ends too early, if a tag other than 'L' or 'N'
   is found, or if some characters are left over after the tree. *)
let read (s : string) : tree =
  let len = String.length s in
  (* [char_at pos] is the character at index [pos], which must exist. *)
  let char_at pos =
    assert (pos < len);
    s.[pos]
  in
  (* [parse pos] reads one tree that begins at index [pos] and returns it
     together with the index of the first character that follows it. *)
  let rec parse (pos : int) : tree * int =
    match char_at pos with
    | 'L' ->
        let symbol = char_at (pos + 1) in
        Leaf symbol, pos + 2
    | 'N' ->
        (* The first sub-tree must be parsed before the second one, since
           where the second one begins depends on where the first ends. *)
        let left, pos = parse (pos + 1) in
        let right, pos = parse pos in
        Node (left, right), pos
    | _ ->
        assert false
  in
  let tree, stop = parse 0 in
  (* The whole string must have been consumed. *)
  assert (stop = len);
  tree

(* [build_alphabet text] returns a fresh table that maps every character
   occurring in [text] to its number of occurrences. Characters that do
   not occur in [text] have no binding. *)

let build_alphabet (text : plain_text) : alphabet =
  let counts : alphabet = Hashtbl.create 256 in
  (* Record one more occurrence of [symbol]; a character that has not been
     seen yet counts as having been seen zero times. *)
  let bump symbol =
    let seen = try Hashtbl.find counts symbol with Not_found -> 0 in
    Hashtbl.replace counts symbol (seen + 1)
  in
  String.iter bump text;
  counts