(* -------------------------------------------------------------------------- *)

(* The search algorithm is slightly generalized so as to return the following
   result: either

     Found k
       A complete match was found at position k in the text.

  or

     Interrupted j
       No match was found (i.e., the position n in the text was reached)
       and the current position in the pattern was j, where j < m. *)

(* Outcome of a search. [Found k]: the pattern occurs at position k of the
   text. [Interrupted j]: the end of the searched text segment was reached
   without a match, at position j (with j < m) of the pattern. *)
type result = Found of int | Interrupted of int

(* This auxiliary function is useful when a result is known in advance to
   be of the second form. *)

(* Extract the pattern position carried by an [Interrupted] result. Receiving
   [Found _] here contradicts the caller's reasoning, so it fails with
   [Assert_failure]. *)
let assertInterrupted r =
  match r with
  | Interrupted j -> j
  | Found _ -> assert false

(* -------------------------------------------------------------------------- *)

(* Step 2. Transcribe the Java code search01, almost literally, as a tail
   recursive function. Adapt it so as to return a result of the form shown
   above. *)

(* The parameters are:

     pattern, m
       The pattern and its length.
     text, n
       The text and its length. Note that n is not necessarily the actual
       length of the text, but can be smaller. This allows us to search
       just a segment of the text.
     j, k
       The current positions in the pattern and in the text.

   The precondition (i.e. our hypothesis about the parameters) is:

     The indices j and k are in bounds:
       0 <= j <= m
       0 <= k <= n
     The sub-string pattern[0..j) is equal to the sub-string text[k-j..k).

*)

(* Naive search. [pattern], [m], [text] and [n] stay fixed throughout, so
   only the pair of positions (j in the pattern, k in the text) evolves.
   On a mismatch, the pattern is re-aligned one position to the right of the
   current alignment k - j and matching restarts at its first character. *)
let search02 pattern m text n j k : result =
  let rec scan j k =
    if j = m then Found (k - j)
    else if k = n then Interrupted j
    else begin
      let same = pattern.[j] = text.[k] in
      if same then scan (j + 1) (k + 1)
      else scan 0 (k - j + 1)
    end
  in
  scan j k

(* Termination argument: at each recursive call,
     either k - j increases
     or k - j remains constant and j increases.
   Since k - j is bounded by n and j is bounded by m, the complexity is
   O(mn). *)

(* -------------------------------------------------------------------------- *)

(* Step 3. In the last case, isolate the sub-case where j is zero, which is
   not problematic. In the other sub-case, split the search interval in two
   sub-intervals: up to position k first, then beyond position k. *)

(* Note that the first sub-search cannot possibly succeed, since we are
   searching for a pattern of length m within the segment text[k-j+1..k),
   whose length is j - 1, where j < m. So, it must return Interrupted j',
   where j' <= j - 1 < j. *)

(* Same search as search02, but a mismatch with j > 0 is handled in two
   phases: first re-scan text[k-j+1..k) alone (by passing k as the text
   length), then resume from the resulting pattern position at text position
   k with the full length n. *)
let rec search03 pattern m text n j k : result =
  if j = m then Found (k - j)
  else if k = n then Interrupted j
  else if pattern.[j] = text.[k] then
    search03 pattern m text n (j + 1) (k + 1)
  else
    match j with
    | 0 -> search03 pattern m text n 0 (k + 1)
    | _ ->
        (* The re-scanned segment has length j - 1 < m, so no match can be
           found in it. *)
        let prefix_state =
          assertInterrupted (search03 pattern m text k 0 (k - j + 1))
        in
        assert (prefix_state < j);
        search03 pattern m text n prefix_state k

(* -------------------------------------------------------------------------- *)

(* Step 4. Note that the first segment of text is known: according to the
   precondition, text[k-j+1..k) is pattern[1..j). Make this explicit by
   searching within the pattern instead of searching within the text. *)

(* Like search03, except that the re-scan after a mismatch with j > 0 reads
   the pattern instead of the text. This is valid because, by the
   precondition, text[k-j+1..k) equals pattern[1..j). *)
let rec search04 pattern m text n j k : result =
  if j = m then Found (k - j)
  else if k = n then Interrupted j
  else if pattern.[j] = text.[k] then
    search04 pattern m text n (j + 1) (k + 1)
  else if j <> 0 then begin
    let fallback = assertInterrupted (search04 pattern m pattern j 0 1) in
    search04 pattern m text n fallback k
  end
  else search04 pattern m text n 0 (k + 1)

(* -------------------------------------------------------------------------- *)

(* We have now isolated a call that searches for the pattern within the
   pattern itself. The outcome of this call is independent of text. It
   can be precomputed and stored in a table. *)

(* Step 5. Rewrite the search algorithm under a form that assumes a table is
   given. The table is indexed by j and has the same length m as the
   pattern. Note: since we had j' < j above, we have table.(j) < j. *)

(* Argue that the search now has time complexity O(n). *)

(* Table-driven search. On a mismatch at pattern position j > 0, jump to the
   precomputed position table.(j) without moving in the text. All arguments
   except the two positions are constant, so a local worker carries only
   (j, k). *)
let loop table pattern m text n j k : result =
  let rec step j k =
    if j = m then Found (k - j)
    else if k = n then Interrupted j
    else if pattern.[j] = text.[k] then step (j + 1) (k + 1)
    else if j <> 0 then step table.(j) k
    else step 0 (k + 1)
  in
  step j k

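(* Complexity argument: the quantity 2k-j grows strictly at every recursive
   call. It grows by 1 on a match (j and k both increase by 1), by 2 when
   j = 0 and only k advances, and by j - table.(j) > 0 when the table is
   followed (k is unchanged). Since this quantity is initially 0 (for the call
   with j = k = 0) and is bounded by 2n (because k <= n and j >= 0), the time
   complexity of loop is O(n). *)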

(* -------------------------------------------------------------------------- *)

(* Step 6. Write the table initialization code. *)

(* For 0 < j < m, we need:
     table.(j) = assertInterrupted (search04 pattern m pattern j 0 1).
   This is the outcome of a search for the whole pattern within pattern[1..j). *)

(* The table can be initialized by the following loop:

     for j = 1 to m - 1 do
       table.(j) <- assertInterrupted (loop table pattern m pattern j 0 1)
     done

   We can use the optimized search function, loop table, even though the table
   is not yet fully initialized, because it will look up the table only at
   positions strictly less than j. This is quite beautiful: the same code can be
   used in the pre-computation phase and in the search phase!

   However, this loop has time complexity O(m^2), which is not good. It can be
   improved, as follows. *)

(* When j is 1, we are supposed to search at offset 1 in a string of length 1,
   so the search stops immediately and the result is the initial state: 0. *)

(* When j > 1, we can again split the search (loop table pattern m pattern j 0 1)
   as a succession of two steps, and write:

     for j = 2 to m - 1 do
       let s = assertInterrupted (loop table pattern m pattern (j - 1) 0 1) in
       table.(j) <- assertInterrupted (loop table pattern m pattern j s (j - 1))
     done

   The first line of the loop body can then be simplified, since the value s has
   just been computed in the previous iteration: indeed, s is just table.(j - 1).
   We obtain the code below: *)

(* Build the fallback table for [pattern], whose length is [m].
   For 0 < j < m, table.(j) is the pattern position reached when the pattern
   is searched for within pattern[1..j); it satisfies table.(j) < j.
   Entry j >= 2 is obtained by resuming, from the state table.(j - 1)
   recorded at offset j - 1, the self-search for one more character. The
   inner [loop] only reads table entries strictly below j, all of which are
   already filled in. *)
let init pattern m : int array =
  (* [Array.make] replaces [Array.create], which was deprecated and then
     removed from the standard library in OCaml 5.0. *)
  let table = Array.make m 0 in
  (* table.(0) is never used *)
  (* table.(1) is initialized to 0 *)
  for j = 2 to m - 1 do
    let s = table.(j - 1) in
    table.(j) <- assertInterrupted (loop table pattern m pattern j s (j - 1))
  done;
  table

(* Complexity argument (informal): the whole loop above is equivalent to just
   one call to loop table pattern m pattern m 0 1, except it stops at every
   input offset j to record its current state in table.(j). So its complexity
   is O(m). *)

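(* Note: an optimization that allows building shortcuts in the table is not
   implemented here. Knuth-Morris-Pratt additionally observes that when
   pattern.[table.(j)] = pattern.[j], falling back to table.(j) is bound to
   fail again against the same text character, so its table points further
   back directly. So this is almost, but not quite, the Knuth-Morris-Pratt
   algorithm: it is the earlier Morris-Pratt variant. *)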

(* -------------------------------------------------------------------------- *)

(* Combine everything to obtain a complete, linear-time search algorithm. *)

(* Entry point: precompute the fallback table for [pattern], then scan the
   whole of [text] starting from the initial state (pattern position 0, text
   position 0). *)
let search pattern text : result =
  let m = String.length pattern in
  let n = String.length text in
  loop (init pattern m) pattern m text n 0 0