(* Most of this file is adapted from OCaml's standard library *)

(***********************************************************************)
(*                                                                     *)
(*                                OCaml                                *)
(*                                                                     *)
(*            Xavier Leroy, projet Cristal, INRIA Rocquencourt         *)
(*                                                                     *)
(*  Copyright 1996 Institut National de Recherche en Informatique et   *)
(*  en Automatique.  All rights reserved.  This file is distributed    *)
(*  under the terms of the GNU Library General Public License, with    *)
(*  the special exception on linking described in file ../LICENSE.     *)
(*                                                                     *)
(***********************************************************************)

type elt = Elt
(* Placeholder element comparison. elt has a single constructor, so any two
   elements are considered equal: the result is always 0. *)
let compare = fun (a : elt) (b : elt) => 0

(* Local boolean type; the set operations below return these constructors. *)
type bool = True | False

(* A set is either Empty or Node(l, v, r, h): left subtree, element,
   right subtree, and the height h cached in the node (read by height). *)
type t = Empty | Node of t * elt * t * int

(* Stand-in for raising an error, usable at any type. It is defined as
   itself, so evaluating it presumably never returns a value.
   NOTE(review): confirm how the ornament tool treats this definition.
   It is used below at every invalid or unreachable case. *)
let rec invalid_arg : forall 'a. 'a = invalid_arg

(* Sets are represented by balanced binary trees (the heights of the
children differ by at most 2). *)

(* Height of a tree: 0 for Empty, otherwise the height cached in the node. *)
let height = fun (tree : t) => match tree with
  | Node(lft, elem, rgt, ht) => ht
  | Empty => 0
end

(* Creates a new node with left child l, value v and right child r.
We must have all elements of l < v < all elements of r.
l and r must be balanced and | height l - height r | <= 2.
The heights of l and r are obtained by calling height (the original
standard-library code inlined this for speed). *)

let create = fun (l : t) (v : elt) (r : t) =>
  let right_h = height r in
  let left_h = height l in
  (* The new node is one level taller than its taller child. *)
  Node(l, v, r, (if lt left_h right_h then plus right_h 1 else plus left_h 1))

(* Same as create, but performs one step of rebalancing if necessary.
Assumes l and r balanced and | height l - height r | <= 3.
Inline expansion of create for better speed in the most frequent case
where no rebalancing is required. *)


(* Builds a node from l, v, r, first performing one single or double
   rotation when one side is more than 2 taller than the other.
   The invalid_arg calls sit on cases that should be unreachable
   when the cached heights are consistent. *)
let bal = fun (l : t) (v : elt) (r : t) =>
  let hl = height l in
  let hr = height r in
      (* Left side too tall. *)
      if gt hl (plus hr 2) then
        match l with
          Empty => invalid_arg [t]
        | Node(ll, lv, lr, xxx) =>
            (* Outer grandchild ll at least as tall as lr: single rotation. *)
            if gte (height ll) (height lr) then
              create ll lv (create lr v r)
            else 
              (* Inner grandchild lr is taller: double rotation around lr. *)
              match lr with
                Empty => invalid_arg [t]
              | Node(lrl, lrv, lrr, xxx)=>
                  create (create ll lv lrl) lrv (create lrr v r)
              end
        end
   else
      (* Right side too tall: mirror image of the case above. *)
      if gt hr (plus hl 2)
      then
        match r with
          Empty => invalid_arg [t]
        | Node(rl, rv, rr, xx) =>
            if gte (height rr) (height rl)
            then create (create l v rl) rv rr
            else
              match rl with
                Empty => invalid_arg [t]
              | Node(rll, rlv, rlr, xx) =>
                  create (create l v rll) rlv (create rlr rv rr)
              end
      end 
     else
        (* Already balanced: build the node directly (same expression as create). *)
        Node(l, v, r, (if gte hl hr then plus hl 1 else plus hr 1))


(* Insertion of one element *)

(* Inserts x into s, rebalancing on the way back up.
   If an element comparing equal to x is already present, the node is
   rebuilt from its own fields rather than returned as-is.
   NOTE(review): presumably this is so the ornamented madd has a
   construction site at which to store the new value. *)
let rec add : elt -> t -> t = fun (x : elt) (s : t) =>
  match s with
  | Empty => Node(Empty, x, Empty, 1)
  | Node(left, pivot, right, ht) =>
      let ord = compare x pivot in
      if lt ord 0 then bal (add x left) pivot right
      else if eq ord 0 then Node(left, pivot, right, ht)
      else bal left pivot (add x right)
  end

(* The one-element set containing elem: a leaf of height 1. *)
let singleton = fun (elem : elt) => Node(Empty, elem, Empty, 1)


(* Beware: these two functions assume that the added v is *strictly*
smaller (or bigger) than all the elements already in the tree; they
do not test for equality with the current min (or max) element.
Indeed, they are only used during the "join" operation, which
respects this precondition.
*)

(* Inserts v as the new leftmost element, rebalancing along the left spine. *)
let rec add_min_element : elt -> t -> t = fun (v : elt) (s : t) =>
  match s with
  | Node (left, pivot, right, ht) => bal (add_min_element v left) pivot right
  | Empty => singleton v
  end

(* Inserts v as the new rightmost element, rebalancing along the right spine. *)
let rec add_max_element : elt -> t -> t = fun (v : elt) (s : t) =>
  match s with
  | Node (left, pivot, right, ht) => bal left pivot (add_max_element v right)
  | Empty => singleton v
  end

(* Same as create and bal, but no assumptions are made on the
relative heights of l and r. *)

let rec join : t -> elt -> t -> t = fun (l : t) (v : elt) (r : t) =>
  match l with
  | Empty => add_min_element v r
  | Node(l_left, l_val, l_right, l_ht) =>
      match r with
      | Empty => add_max_element v l
      | Node(r_left, r_val, r_right, r_ht) =>
          (* Descend into the taller side until the heights are within 2,
             then build the node and rebalance on the way back up. *)
          if gt l_ht (plus r_ht 2) then bal l_left l_val (join l_right v r)
          else if gt r_ht (plus l_ht 2) then bal (join l v r_left) r_val r_right
          else create l v r
      end
  end
    (* Smallest and greatest element of a set *)

(* Leftmost (smallest) element; calls invalid_arg on the empty set. *)
let rec min_elt : t -> elt = fun (s : t) => match s with
  | Empty => invalid_arg [elt]
  | Node(left, pivot, right, ht) =>
      match left with
      | Node(ll, lv, lr, lh) => min_elt left
      | Empty => pivot
      end
end


(* Rightmost (greatest) element; calls invalid_arg on the empty set. *)
let rec max_elt : t -> elt = fun (s : t) => match s with
  | Empty => invalid_arg [elt]
  | Node(left, pivot, right, ht) =>
      match right with
      | Node(rl, rv, rr, rh) => max_elt right
      | Empty => pivot
      end
end

(* Remove the smallest element of the given set *)

(* Removes the leftmost (smallest) node of a non-empty tree, rebalancing on
   the way back up. Calls invalid_arg on Empty.
   The descent must inspect the LEFT child, consistently with min_elt: a
   node whose left child is Empty holds the minimum, and only its right
   subtree remains. *)
let rec remove_min_elt : t -> t = fun (s : t) => match s with
    Empty => invalid_arg [t]
  | Node(l, v, r, yy) =>
     match l with
     | Empty => r
     | Node(x,xx,xxx,xxxx) => bal (remove_min_elt l) v r
     end
end


(* Merge two trees t1 and t2 into one.
All elements of t1 must precede the elements of t2.
Assumes | height t1 - height t2 | <= 2. *)

let merge = fun (t1 t2 : t) =>
  match t1 with
  | Empty => t2
  | Node(l1, v1, r1, h1) =>
      match t2 with
      | Empty => t1
      (* The minimum of t2 becomes the root separating t1 from the rest of t2. *)
      | Node(l2, v2, r2, h2) => bal t1 (min_elt t2) (remove_min_elt t2)
      end
  end

(* Merge two trees t1 and t2 into one.
All elements of t1 must precede the elements of t2.
No assumption is made on the heights of t1 and t2. *)

let concat = fun (t1 t2 : t) =>
  match t1 with
  | Empty => t2
  | Node(l1, v1, r1, h1) =>
      match t2 with
      | Empty => t1
      (* join, unlike bal, tolerates arbitrary height differences. *)
      | Node(l2, v2, r2, h2) => join t1 (min_elt t2) (remove_min_elt t2)
      end
  end

(* Splitting. split x s returns SplitSet(l, present, r) where
- l is the set of elements of s that are < x
- r is the set of elements of s that are > x
- present is False if s contains no element equal to x,
or True if s contains an element equal to x. *)

(* Result of split: elements below x, whether x was present, elements above x. *)
type split_set = SplitSet of t * bool * t

let rec split : elt -> t -> split_set = fun (x : elt) (s : t) =>
  match s with
  | Empty => SplitSet(Empty, False, Empty)
  | Node(left, pivot, right, ht) =>
      let ord = compare x pivot in
      if lt ord 0 then
        (* x lies left of pivot: pivot and right go to the "above" part. *)
        match split x left with
        | SplitSet(below, found, above) => SplitSet(below, found, join above pivot right)
        end
      else if eq ord 0 then SplitSet(left, True, right)
      else
        (* x lies right of pivot: left and pivot go to the "below" part. *)
        match split x right with
        | SplitSet(below, found, above) => SplitSet(join left pivot below, found, above)
        end
  end
(* Implementation of the set operations *)
(* The empty set. *)
let empty = Empty

(* True exactly when the set has no elements. *)
let is_empty = fun (s : t) =>
  match s with
  | Node(left, pivot, right, ht) => False
  | Empty => True
  end

(* Membership test by binary search down the tree. *)
let rec mem : elt -> t -> bool = fun (x : elt) (s : t) =>
  match s with
  | Empty => False
  | Node(left, pivot, right, ht) =>
      let ord = compare x pivot in
      if eq ord 0 then True
      else if lt ord 0 then mem x left
      else mem x right
  end


(* This is a copy of mem to circumvent a restriction of the ornament tool:
   a function can only be ornamented once in a recursive block *)
(* Same behaviour as mem; kept as a separate definition so that it can be
   ornamented independently (into mfind). *)
let rec mem2 : elt -> t -> bool = fun (x : elt) (s : t) =>
  match s with
  | Empty => False
  | Node(left, pivot, right, ht) =>
      let ord = compare x pivot in
      if lt ord 0 then mem2 x left
      else if eq ord 0 then True
      else mem2 x right
  end


(* Removes x from s if present, rebalancing on the way back up. The node
   holding x is replaced by the merge of its two subtrees. *)
let rec remove : elt -> t -> t = fun (x : elt) (s : t) =>
  match s with
  | Empty => Empty
  | Node(left, pivot, right, ht) =>
      let ord = compare x pivot in
      if lt ord 0 then bal (remove x left) pivot right
      else if eq ord 0 then merge left right
      else bal left pivot (remove x right)
  end

(* Number of elements, counted by a full traversal. *)
let rec cardinal : t -> int = fun (s : t) =>
  match s with
  | Empty => 0
  | Node(left, pivot, right, ht) =>
      let n_left = cardinal left in
      let n_right = cardinal right in
      plus n_left (plus 1 n_right)
  end

(* Maps: the same tree shape as t, but each node MNode(l, k, v, r, h) also
   stores a value v of type 'a next to its key k. *)
type 'a m = MEmpty | MNode of 'a m * elt * 'a * 'a m * int

(* Projection from maps to sets: drops the value stored in each node and
   keeps keys, tree shape and cached heights. Its shape (one constructor of t
   per constructor of 'a m) is what the ornament declaration later in the
   file relies on, so keep it in this direct form. *)
let rec orn_m_t : forall 'a. 'a m -> t = fun ['a] (m : 'a m) =>
   match m with
   | MEmpty => Empty
   | MNode(l,k,v,r,h) => Node(orn_m_t ['a] l,k,orn_m_t ['a] r,h)
end

(* Optional value; used as the ornamented counterpart of bool. *)
type 'a option = Some of 'a | None

(* Projection from 'a option to bool: Some maps to True, None to False. *)
let is_some = fun ['a] (x : 'a option) =>
  match x with Some(y) => True | None => False end


(* Map counterpart of split_set: the presence flag becomes an optional value. *)
type 'a split_map = SplitMap of 'a m * 'a option * 'a m
(* Projection from split_map to split_set, applying orn_m_t to both trees
   and is_some to the optional value. *)
let rec orn_split : forall 'a. 'a split_map -> split_set = fun ['a] (s : 'a split_map) =>
  match s with
  | SplitMap(l,v,r) =>
   SplitSet(orn_m_t ['a] l,is_some ['a] v,orn_m_t ['a] r)
  end

(* Declares the three projections above as ornaments:
   'a m over t, 'a split_map over split_set, and 'a option over bool. *)
let forall 'a ornament orn_m_t ['a] : 'a m -> t
and orn_split ['a] : 'a split_map -> split_set
and is_some ['a] : 'a option -> bool

(* Lifts each set operation to a map operation by ornamentation.
   Reading the signatures:
   - {orn_m_t ['a]} marks a set position that becomes an 'a m.
   - {orn_split ['a]} and {is_some ['a]} do the same for split results and
     booleans.
   - [ty] marks a position kept at type ty.
   - +['a] marks an extra argument of type 'a that the set version does not
     take (presumably the value to store).
   NOTE(review): `_` and the `+` in `forall +'a` are interpreted by the
   ornament tool; their exact meaning is not visible in this file.
   mfind is lifted from mem2 rather than mem because a function can only be
   ornamented once per recursive block. *)
let ornament
    mheight from height with forall +'a. {orn_m_t ['a]} -> _

and mcreate from create with forall +'a. {orn_m_t ['a]} -> _ -> +['a] -> {orn_m_t ['a]} -> {orn_m_t ['a] } 
and mbal from bal with forall +'a. {orn_m_t ['a]} -> _ -> +['a] -> {orn_m_t ['a]} -> {orn_m_t ['a] } 
and madd from add with forall +'a. [elt] -> +['a] -> {orn_m_t ['a]} -> {orn_m_t ['a] } 
and msingleton from singleton with forall +'a. [elt]  -> +['a] -> {orn_m_t ['a]}
and madd_min_element from add_min_element with forall +'a. [elt] -> +['a] -> {orn_m_t ['a]} -> {orn_m_t ['a] }
and madd_max_element from add_max_element with forall +'a. [elt] -> +['a] -> {orn_m_t ['a]} -> {orn_m_t ['a] }
and mjoin from join with forall +'a. {orn_m_t ['a]} -> [elt] -> +['a] -> {orn_m_t ['a]} -> {orn_m_t ['a]}
and mmin_elt from min_elt with forall +'a. {orn_m_t ['a]} -> [elt]
and mmax_elt from max_elt with forall +'a. {orn_m_t ['a]} -> [elt]
and mremove_min_elt from remove_min_elt with forall +'a. {orn_m_t ['a]} -> {orn_m_t ['a]}
and mmerge from merge with forall +'a. {orn_m_t ['a]} -> {orn_m_t ['a]} -> {orn_m_t ['a]}
and mconcat from concat with forall +'a. {orn_m_t ['a]} -> {orn_m_t ['a]} -> {orn_m_t ['a]}
and msplit from split with forall +'a. [elt] -> {orn_m_t ['a]} -> {orn_split ['a]}
and mempty from empty with forall +'a. {orn_m_t ['a]}
and mis_empty from is_empty with forall +'a. {orn_m_t ['a]} -> [bool]
and mmem from mem with forall +'a. [elt] -> {orn_m_t ['a]} -> [bool]
and mfind from mem2 with forall +'a. [elt] -> {orn_m_t ['a]} -> {is_some ['a]}
and mremove from remove with forall +'a. [elt] -> {orn_m_t ['a]} -> {orn_m_t ['a]}
and mcardinal from cardinal with forall +'a. {orn_m_t ['a]} -> [int]