Below is the file 'automate.ml' from this revision. You can also download the file.

open Viz_misc

(* Evaluates [Giochannel.init] and then [Gspawn.init]; the value of [init]
   is [Gspawn.init].
   NOTE(review): presumably this exists only to make sure both binding
   modules are linked and initialised before this module uses them --
   confirm. *)
let init =
  Giochannel.init ; Gspawn.init

(* Debug switch for this module: true when [Viz_misc.debug] enables the
   "automate" category. Guards every [log] call below. *)
let debug = Viz_misc.debug "automate"

(* [log fmt ...] formats its arguments like [Printf.printf] and prints the
   result on stderr, prefixed with "### automate: " and followed by a
   newline; stderr is flushed right away. *)
let log fmt =
  let emit msg = Printf.eprintf "### automate: %s\n%!" msg in
  Printf.kprintf emit fmt


(** Type definitions *)

(* Problems that terminate a watched channel; they are handed to the
   watch's [exn_cb]:
   - [`HANGUP]: the other end hung up;
   - [`FAILURE]: the channel reported a failure;
   - [`ERROR exn]: a channel operation or handler raised [exn]. *)
type pb = [
  | `HANGUP
  | `FAILURE
  | `ERROR of exn ]
(* State of a watch: no GLib watch installed ([`DISABLED]), a watch
   installed under the given source id ([`WATCH]), or the channel was shut
   down because of one of the problems above ([pb], set by [exn_cb]). *)
type watch_state = [
  | `DISABLED
  | `WATCH of Giochannel.source_id
  | pb ]
(* Common part of the read and write watches on the child's pipes. *)
type watch = {
    w_name       : string ;          (* label used in debug messages *)
    w_chan       : Giochannel.t ;
    mutable w_state    : watch_state ;
    mutable exn_cb     : pb -> unit ; (* called when a problem is detected *)
  }
(* Write side (the child's stdin). *)
type in_watch = {
    in_w         : watch ;
    (* (command id, encoded command) pairs waiting to be written, in
       submission order; the head is the one currently being written *)
    mutable in_data : (int * string) list ;
    (* number of bytes of the head of [in_data] already written *)
    mutable in_pos  : int ;
  }
(* Read side (the child's stdout or stderr). *)
type out_watch = {
    out_w        : watch ;
    out_sb       : string ;     (* scratch area passed to read_chars *)
    out_buffer   : Buffer.t ;   (* bytes read and not consumed yet *)
    mutable out_cb : (Buffer.t -> unit) (* called with out_buffer after each read *)
  }

type command_id = int
(* Result of a command, chosen from the code of its last output chunk:
   0 -> [`OUTPUT], 1 -> [`SYNTAX_ERROR], 2 -> [`ERROR] (see out_cb). *)
type output = [
  | `OUTPUT       of string
  | `ERROR        of string
  | `SYNTAX_ERROR of string]
(* One decoded stdio chunk: command id, code, "is last chunk", data. *)
type chunk = command_id * int * bool * string
(* A running (or exiting) "mtn automate stdio" subprocess. *)
type process = {
    p_in  :  in_watch ;
    p_out : out_watch ;
    p_err : out_watch ;
    mutable cmd_number : command_id ;  (* id given to the next command *)
    (* commands still waiting for a result, most recently submitted first *)
    mutable callbacks  : (command_id * (output -> unit)) list ;
    (* chunks received for commands whose last chunk has not arrived yet,
       most recent chunk first *)
    mutable chunks     : (command_id * chunk list ref) list ;
    mutable state      : [`RUNNING|`EXITING|`EXIT of int] ;
    (* called by check_exit with the problem to report and the stderr text *)
    mutable exit_cb    : (pb -> string -> unit) ;
  }


(* Controller: how to start the subprocess, and the current one, if any. *)
type t = {
    mtn      : string ;    (* monotone executable, looked up in PATH *)
    db_fname : string ;    (* database passed with "-d" *)
    mutable process : process option ;
  }

(* Accessors for the fixed configuration of a controller:
   the (executable, database) pair and the database file name alone. *)
let get_info ctrl =
  (ctrl.mtn, ctrl.db_fname)
let get_dbfname { db_fname = fname } =
  fname



(* Render a list of GIOChannel conditions as a fixed six-character mask
   "IOHEPN" (IN, OUT, HUP, ERR, PRI, NVAL), with '.' for each condition
   absent from [cond]. Used only in debug messages. *)
let string_of_conditions cond =
  let flags =
    [ `IN, "I" ; `OUT, "O" ;
      `HUP, "H" ; `ERR, "E" ;
      `PRI, "P" ; `NVAL, "N" ] in
  String.concat ""
    (List.map (fun (flag, letter) -> if List.mem flag cond then letter else ".") flags)
(* Short textual form of a channel problem, for debug messages. *)
let string_of_pb pb =
  match pb with
  | `HANGUP    -> "HUP"
  | `FAILURE   -> "ERR"
  | `ERROR exn -> Printf.sprintf "EXN '%s'" (Printexc.to_string exn)




(* Handle a watch wake-up that did not carry the condition the watch waits
   for ([`IN] for readers, [`OUT] for the writer). The problem is reported
   to the watch's exception callback:
   - [`FAILURE] when the conditions include [`ERR] or [`NVAL];
   - [`HANGUP] otherwise (typically a bare [`HUP]).
   poll(2) reports POLLHUP, POLLERR and POLLNVAL even when they were not
   requested, so [`NVAL] and mixed sets can arrive here. They are handled
   instead of asserted on for a reason: an exception raised here is
   swallowed by [write_cb]/[read_cb], which then return [true] and keep the
   watch installed. The same condition could then fire again and again. *)
let error_cb w conditions =
  if debug then log "%s hup_cb = %s" w.w_name (string_of_conditions conditions) ;
  if List.mem `ERR conditions || List.mem `NVAL conditions
  then w.exn_cb `FAILURE
  else w.exn_cb `HANGUP


(* Write as much of [data] as the child's stdin accepts, starting at offset
   [w.in_pos], and advance [w.in_pos] by the number of bytes written.
   Returns [true] once [data] has been written completely, and [false] when
   part of it remains or the write failed. A [Giochannel.Error] is not
   propagated; it is reported to the watch's exception callback as
   [`ERROR exn].
   Note: the stdin channel is created non-blocking (make_in_watch passes
   ~nonblock:true), so [`AGAIN] is possible. [w.in_pos] is then left as is,
   and the remaining bytes are written on a later callback. *)
let do_write w data =
  let name = w.in_w.w_name in
  let bytes_written = ref 0 in
  try
    begin match
      Giochannel.write_chars w.in_w.w_chan ~bytes_written ~off:w.in_pos data
    with
    | `AGAIN ->
        if debug then log "%s cb: EAGAIN ?" name ;
        false
    | `NORMAL written ->
        if debug then log "%s cb: wrote %d" name written ;
        w.in_pos <- w.in_pos + written ;
        w.in_pos >= String.length data
    end
  with Giochannel.Error (_, msg) as exn ->
    if debug then log "%s cb: error %s, wrote %d" name msg !bytes_written ;
    w.in_w.exn_cb (`ERROR exn) ;
    false

(* Body of the stdin watch callback.
   - Empty queue: mark the watch [`DISABLED] so that [write_cb] removes the
     GLib source; [setup_watch_write] installs a new one later.
   - Otherwise, on [`OUT], write more of the head entry. Once it is fully
     written, drop it from the queue and reset the offset.
   - A wake-up without [`OUT] is handed to [error_cb]. *)
let _write_cb w conditions =
  let name = w.in_w.w_name in
  if debug then log "%s cb = %s" name (string_of_conditions conditions) ;
  match w.in_data with
  | [] ->
      if debug then log "%s cb: empty write queue, removing watch" name ;
      w.in_w.w_state <- `DISABLED
  | (nb, data) :: rest ->
      let remaining = String.length data - w.in_pos in
      (* send_data rejects empty commands, and finished entries are popped *)
      assert (String.length data > 0 && remaining > 0) ;
      if debug then begin
        log "%s cb: %d left in buffer" name remaining ;
        if w.in_pos = 0 then log "%s cb: writing '%s'" name (String.escaped data)
      end ;
      if not (List.mem `OUT conditions) then
        error_cb w.in_w conditions
      else if do_write w data then begin
        if debug then log "%s cb: finished writing cmd %d" name nb ;
        w.in_data <- rest ;
        w.in_pos <- 0
      end

(* Body of a read watch callback (the child's stdout or stderr).
   On [`IN], read what is available into the scratch area [w.out_sb],
   append it to [w.out_buffer] and pass the buffer to [w.out_cb].
   End of file is reported to the watch's exception callback as
   [`HANGUP]. Any exception raised while reading, or inside [w.out_cb]
   (e.g. a protocol error while decoding stdout), is reported as
   [`ERROR exn]. A wake-up without [`IN] is handed to [error_cb]. *)
let _read_cb w conditions =
  let name = w.out_w.w_name in
  if debug then log "%s cb = %s" name (string_of_conditions conditions) ;
  if List.mem `IN conditions
  then begin
    try
      match Giochannel.read_chars w.out_w.w_chan w.out_sb with
      | `NORMAL read ->
          if debug then log "%s cb: read %d" name read ;
          Buffer.add_substring w.out_buffer w.out_sb 0 read ;
          w.out_cb w.out_buffer
      | `EOF ->
          (* EOF means the child closed its end of the pipe: a hang-up,
             the same as a bare [`HUP] condition. Reporting it as
             [`HANGUP] (not [`FAILURE]) matters for [check_exit]. That
             function reports the first watch state that is not
             [`HANGUP], so it keeps waiting for the other watches, notably
             stderr whose contents go to the exit callback. A [`FAILURE]
             here would end the process before stderr was fully read. *)
          if debug then log "%s cb: eof ?" name ;
          w.out_w.exn_cb `HANGUP
      | `AGAIN ->
          if debug then log "%s cb: AGAIN" name
    with exn ->
      if debug then log "%s cb: error %s" name (Printexc.to_string exn) ;
      w.out_w.exn_cb (`ERROR exn)
  end
  else
    error_cb w.out_w conditions


(* Value returned to GLib by the watch callbacks: keep the source installed
   only while the watch is still in the [`WATCH _] state. *)
let reschedule_watch w =
  match w.w_state with
  | `DISABLED | #pb -> false
  | `WATCH _ -> true

(* GLib callbacks for the stdin watch ([write_cb]) and the stdout/stderr
   watches ([read_cb]). Each one runs the real handler and returns
   [reschedule_watch] for the watch. If an exception escapes the handler,
   it is only logged (in debug mode) and [true] is returned, so the source
   stays installed. *)
let write_cb w c =
  let watch = w.in_w in
  try
    _write_cb w c ;
    reschedule_watch watch
  with exn ->
    if debug then begin
      let reason = Printexc.to_string exn in
      log "write cb %s: uncaught exception '%s'" watch.w_name reason
    end ;
    true

let read_cb w c =
  let watch = w.out_w in
  try
    _read_cb w c ;
    reschedule_watch watch
  with exn ->
    if debug then begin
      let reason = Printexc.to_string exn in
      log "read cb %s: uncaught exception '%s'" watch.w_name reason
    end ;
    true


(* Queue the encoded command [data] (id [nb]) for writing to the child's
   stdin. If no watch is installed, start a fresh queue and install an
   [`OUT] watch. If a watch is already running, append to its queue. Queuing
   on a channel that was already shut down (state in [pb]) is a
   programming error. *)
let setup_watch_write w nb data =
  let entry = (nb, data) in
  match w.in_w.w_state with
  | `WATCH _ ->
      w.in_data <- w.in_data @ [ entry ]
  | `DISABLED ->
      assert (w.in_data = []) ;
      w.in_data <- [ entry ] ;
      w.in_pos <- 0 ;
      w.in_w.w_state <-
        `WATCH (Giochannel.add_watch w.in_w.w_chan [ `OUT ; `HUP ; `ERR ] (write_cb w))
  | #pb ->
      assert false


(* Install the read watch ([`IN], [`HUP], [`ERR]) on a read side that has
   no watch yet, and record its source id. *)
let setup_watch_read w =
  let watch = w.out_w in
  assert (watch.w_state = `DISABLED) ;
  watch.w_state <-
    `WATCH (Giochannel.add_watch watch.w_chan [ `IN ; `HUP ; `ERR ] (read_cb w))


(* Install the exception callback of watch [w]. [cb] is first applied to
   the watch itself, so one handler can serve several watches. *)
let setup_exn_cb w cb =
  let handler = cb w in
  w.exn_cb <- handler

(* Wrap [fd] in a GIOChannel with no encoding (raw bytes) and no GLib-side
   buffering, optionally switched to non-blocking mode. *)
let setup_channel ~nonblock fd =
  let channel = Giochannel.new_fd (some fd) in
  if nonblock then Giochannel.set_flags_noerr channel [ `NONBLOCK ] ;
  Giochannel.set_encoding channel None ;
  Giochannel.set_buffered channel false ;
  channel
(* A watch on [chan] with no GLib source yet and a no-op exception
   callback (to be replaced with setup_exn_cb). *)
let make_watch name chan =
  { w_state = `DISABLED ; exn_cb = ignore ; w_name = name ; w_chan = chan }
(* Write side (child's stdin): non-blocking channel, empty queue; the
   watch is only installed when data is queued. *)
let make_in_watch name fd =
  let in_w = make_watch name (setup_channel ~nonblock:true fd) in
  { in_w = in_w ; in_data = [] ; in_pos = 0 }
(* Read side (child's stdout/stderr): blocking channel, a 4096-byte read
   area and an accumulation buffer; the read watch is installed now. *)
let make_out_watch name fd =
  let out_w = make_watch name (setup_channel ~nonblock:false fd) in
  let w =
    { out_w = out_w ;
      out_sb = String.create 4096 ;
      out_buffer = Buffer.create 1024 ;
      out_cb = ignore } in
  setup_watch_read w ;
  w








(* Queue [data] (already encoded command number [nb]) on the process's
   stdin. Raises [Invalid_argument] for an empty string. *)
let send_data p nb data =
  match data with
  | "" -> invalid_arg "Automate.send_data: empty string"
  | _ -> setup_watch_write p.p_in nb data



(* Encode a command for "automate stdio": 'l', then "<byte length>:<arg>"
   for every argument, then "e" and a newline.
   e.g. ["a"; "bc"] -> "l1:a2:bce\n" *)
let encode_stdio cmd =
  let encode_arg arg = string_of_int (String.length arg) ^ ":" ^ arg in
  String.concat "" ("l" :: List.map encode_arg cmd @ [ "e\n" ])



(* Length of the prefix of [b] that ends with its fourth ':' (i.e. the
   index just after that colon). Raises [Invalid_argument] (from
   [Buffer.nth]) when [b] contains fewer than four colons. *)
let find_four_colon b =
  let rec scan pos remaining =
    if remaining = 0 then pos
    else
      let remaining = if Buffer.nth b pos = ':' then remaining - 1 else remaining in
      scan (pos + 1) remaining
  in
  scan 0 4

(* Return the [len] bytes of [b] starting at [off], and leave in [b] only
   what followed them (the first [off] bytes are discarded too).
   Raises [Invalid_argument] (from [Buffer.sub]) with [b] untouched if the
   range does not fit in [b]. *)
let truncate_buffer b off len =
  let stop = off + len in
  let extracted = Buffer.sub b off len in
  let remainder = Buffer.sub b stop (Buffer.length b - stop) in
  Buffer.clear b ;
  Buffer.add_string b remainder ;
  extracted

(* Try to decode one stdio output chunk from the front of [b].
   A chunk is a header "<cmd number>:<code>:<last>:<length>:" followed by
   <length> bytes of data. <code> is taken as a single digit, and the chunk
   is flagged as the command's last one when <last> is 'l'.
   Returns [`CHUNK (number, code, is_last, data)] and removes header and
   data from [b]. Returns [`INCOMPLETE] and leaves [b] unchanged when the
   header or the data has not been fully received yet.
   Any [Invalid_argument] raised while parsing (e.g. by [Buffer.nth] inside
   find_four_colon when fewer than four ':' have arrived) is treated as
   [`INCOMPLETE]. A non-numeric number or length makes [int_of_string]
   raise [Failure], which is not caught here. *)
let decode_stdio_chunk b =
  try
    let header_len = find_four_colon b in
    let h = Buffer.sub b 0 header_len in
    let c1 = String.index_from h 0 ':' in
    (* NOTE(review): assumes string_slice ~e:c1 h is h.[0 .. c1-1] -- confirm *)
    let number = int_of_string (string_slice ~e:c1 h) in
    let code   = int_of_char h.[c1 + 1] - int_of_char '0' in
    let c2 = String.index_from h (c1 + 1) ':' in
    let last   = h.[c2 + 1] in
    let c3 = String.index_from h (c2 + 1) ':' in
    let c4 = String.index_from h (c3 + 1) ':' in
    let len   = int_of_string (string_slice ~s:(c3 + 1) ~e:c4 h) in
    if Buffer.length b < header_len + len
    then
      `INCOMPLETE
    else
      (* drops the header and extracts the data in one step *)
      let data = truncate_buffer b header_len len in
      `CHUNK (number, code, last = 'l', data)
  with Invalid_argument _ ->
    `INCOMPLETE


(* A command counts as aborted once it no longer has a registered
   callback (it was cancelled, or its result was already delivered). *)
let aborted_cmd p nb =
  not (List.exists (fun (id, _) -> id = nb) p.callbacks)

(* Output callback for the child's stdout: decode and dispatch every
   complete chunk currently in [b], recursing until decode_stdio_chunk
   reports [`INCOMPLETE].
   - Chunks of a command that has no callback any more (aborted_cmd, e.g.
     after [abort]) are dropped, together with the chunks already
     accumulated for it.
   - Non-final chunks are pushed onto [p.chunks] for their command.
   - On the final chunk, the command's chunks are concatenated in arrival
     order and its callback is removed from [p.callbacks]. The callback is
     then invoked from a GLib idle handler, not synchronously, with
     [`OUTPUT], [`SYNTAX_ERROR] or [`ERROR] for codes 0, 1 and 2 of the
     final chunk. Any other code raises [Failure]. *)
let rec out_cb p b =
  match decode_stdio_chunk b with
  | `INCOMPLETE ->
      ()

  | `CHUNK (nb, _, _, _) when aborted_cmd p nb ->
      p.chunks <- List.remove_assoc nb p.chunks ;
      out_cb p b

  (* intermediate chunk: accumulate, most recent first *)
  | `CHUNK ((nb, code, false, data) as chunk) ->
      if debug then log "decoded a chunk" ;
      let previous_chunks =
	try List.assoc nb p.chunks
	with Not_found ->
	  let c = ref [] in
	  p.chunks <- (nb, c) :: p.chunks ;
	  c in
      previous_chunks := chunk :: !previous_chunks ;
      out_cb p b

  (* last chunk: assemble the whole output and deliver it *)
  | `CHUNK ((nb, code, true, data) as chunk) ->
      if debug then log "decoded last chunk" ;
      let chunks =
	try
	  let c = List.assoc nb p.chunks in
	  p.chunks <- List.remove_assoc nb p.chunks ;
	  List.rev (chunk :: !c)
	with Not_found ->
	  [ chunk ] in
      (* cannot raise Not_found: aborted commands were filtered above *)
      let cb = List.assoc nb p.callbacks in
      p.callbacks <- List.remove_assoc nb p.callbacks ;
      let msg =
	String.concat ""
	  (List.map (fun (_, _, _, d) -> d) chunks) in
      let data =
	match code with
	| 0 -> `OUTPUT msg
	| 1 -> `SYNTAX_ERROR msg
	| 2 -> `ERROR msg
	| _ -> failwith "invalid_code in automate stdio output" in
      ignore (Glib.Idle.add ~prio:0 (fun () -> cb data ; false)) ;
      out_cb p b



(* Once the child has been reaped ([`EXIT]), decide whether to call
   [p.exit_cb]. The reported state is the first one, in the order stdin,
   stdout, stderr, that is not [`HANGUP] (or [`HANGUP] if all are). The
   callback runs, with the stderr collected so far, only when that state
   is a problem; a watch still [`WATCH]ing or [`DISABLED] means "not yet". *)
let check_exit p =
  match p.state with
  | `RUNNING | `EXITING -> ()
  | `EXIT _ ->
      let stderr = Buffer.contents p.p_err.out_buffer in
      let states =
        [ p.p_in.in_w.w_state ; p.p_out.out_w.w_state ; p.p_err.out_w.w_state ] in
      let verdict =
        try List.find (fun s -> s <> `HANGUP) states
        with Not_found -> `HANGUP in
      begin match verdict with
      | #pb as r -> p.exit_cb r stderr
      | `WATCH _ | `DISABLED -> ()
      end

(* Exception callback shared by the three watches of [p]. It moves a
   running process to [`EXITING], shuts the channel down without flushing,
   records the problem as the watch state, then re-checks for exit. *)
let exn_cb p w r =
  if debug then log "%s exn_cb: %s" w.w_name (string_of_pb r) ;
  (match p.state with
   | `RUNNING -> p.state <- `EXITING
   | `EXITING | `EXIT _ -> ()) ;
  Giochannel.shutdown w.w_chan false ;
  w.w_state <- (r : pb :> watch_state) ;
  check_exit p

(* Child-watch callback: the child [pid] terminated with status [st].
   A stdin with no installed watch would never notice, so it is marked as
   hung up here. The process then moves to [`EXIT st] and check_exit runs. *)
let reap_cb p pid st =
  if debug then log "reap_cb: %d" st ;
  Gspawn.close_pid pid ;
  begin match p.p_in.in_w.w_state with
  | `DISABLED -> exn_cb p p.p_in.in_w `HANGUP
  | `WATCH _ | #pb -> ()
  end ;
  p.state <- `EXIT st ;
  check_exit p




(* Encode [cmd], queue it on [p]'s stdin under the next command id,
   register [cb] for its result, and return the id. If queuing raises,
   neither the counter nor the callback table is modified. *)
let _submit p cmd cb =
  Viz_misc.log "mtn" "sending command '%s'" (String.concat " " cmd) ;
  let id = p.cmd_number in
  let encoded = encode_stdio cmd in
  send_data p id encoded ;
  p.cmd_number <- succ id ;
  p.callbacks <- (id, cb) :: p.callbacks ;
  id


(* Start "<mtn> -d <db> automate stdio" with pipes on stdin, stdout and
   stderr (mtn is looked up in PATH) and return its process record:
   - stdin is a non-blocking channel; its write watch is only installed
     when a command is queued (setup_watch_write);
   - stdout and stderr read watches are installed by make_out_watch.
     stdout data is decoded by [out_cb]; stderr keeps the default no-op
     out_cb, so its data only accumulates in its buffer (read by
     check_exit);
   - the child is not reaped automatically: a child watch calls [reap_cb];
   - problems on any channel go through [exn_cb].
   [exit_cb] is a placeholder that asserts if called. The caller must
   replace it (see ensure_process). *)
let spawn mtn db =
  let cmd = [ mtn ; "-d" ; db ; "automate" ; "stdio" ] in
  if Viz_misc.debug "exec"
  then Printf.eprintf "### exec: Running '%s'\n%!" (String.concat " " cmd) ;
  let flags =
    [ `PIPE_STDIN ; `PIPE_STDOUT ; `PIPE_STDERR ;
      `SEARCH_PATH ; `DO_NOT_REAP_CHILD] in
  let child = Gspawn.async_with_pipes ~flags cmd in
  let p =
    { p_in  = make_in_watch  "stdin"  child.Gspawn.standard_input  ;
      p_out = make_out_watch "stdout" child.Gspawn.standard_output ;
      p_err = make_out_watch "stderr" child.Gspawn.standard_error  ;
      state = `RUNNING ;
      cmd_number = 0 ;
      callbacks = [] ;
      chunks = [] ;
      exit_cb = (fun _ -> assert false) (* replaced by the caller *)
    } in
  let pid = some child.Gspawn.pid in
  ignore (Gspawn.add_child_watch ~prio:50 pid (reap_cb p pid)) ;
  p.p_out.out_cb <- out_cb p ;
  setup_exn_cb p.p_in.in_w   (exn_cb p) ;
  setup_exn_cb p.p_out.out_w (exn_cb p) ;
  setup_exn_cb p.p_err.out_w (exn_cb p) ;
  p





(* Exit callback installed by ensure_process. It is called by check_exit
   once process [p] has exited, with the problem [r] and the collected
   [stderr] text.
   - If [p] is still the controller's current process, detach it so the
     next command spawns a new one.
   - Every command still pending on [p] is answered with [`ERROR stderr],
     whether or not [p] is still current. If [p] was already replaced
     (e.g. after it became [`EXITING]), skipping this would leave its
     callbacks unanswered forever, and submit_sync would loop forever.
   - The pending table is cleared before the callbacks run. Any output
     arriving later for those commands is then discarded by out_cb (they
     count as aborted) instead of invoking a callback a second time.
   TODO: display a dialog box in case of error. *)
let exit_cb ctrl p r stderr =
  if debug then log "exit_cb: r='%s' stderr='%s'" (string_of_pb r) stderr ;
  begin match ctrl.process with
  | Some p' when p' == p -> ctrl.process <- None
  | Some _ | None -> ()
  end ;
  let pending = p.callbacks in
  p.callbacks <- [] ;
  List.iter (fun (_, cb) -> cb (`ERROR stderr)) pending


(* Return the controller's running process. If there is none, or it is
   exiting or has exited, spawn a fresh one, wire its exit callback and
   make it current. *)
let ensure_process ctrl =
  let respawn () =
    let p = spawn ctrl.mtn ctrl.db_fname in
    p.exit_cb <- exit_cb ctrl p ;
    ctrl.process <- Some p ;
    p in
  match ctrl.process with
  | Some p ->
      begin match p.state with
      | `RUNNING -> p
      | `EXITING | `EXIT _ -> respawn ()
      end
  | None -> respawn ()









(* Create a controller for executable [mtn] and database file [db]. The
   subprocess is only spawned when the first command is submitted. *)
let make mtn db =
  { process = None ; db_fname = db ; mtn = mtn }

(* Ask the current subprocess, if it is running, to terminate by closing
   its stdin. The stdin watch (if any) is removed, the channel is shut
   down without flushing, and its state becomes [`HANGUP]. The normal
   reap_cb / check_exit path then completes the shutdown. Does nothing if
   no process is running.
   The process is also moved to [`EXITING] so that ensure_process stops
   handing it out. A process left [`RUNNING] with a [`HANGUP] stdin would
   be reused by the next submit, and setup_watch_write asserts on that
   state. ensure_process now spawns a fresh process instead. *)
let exit ctrl =
  match ctrl.process with
  | Some ({ state = `RUNNING } as p) ->
      if debug then log "forced exit" ;
      p.state <- `EXITING ;
      let w = p.p_in.in_w in
      begin
        match w.w_state with
        | `WATCH id ->
            Giochannel.remove_watch id
        | _ -> ()
      end ;
      Giochannel.shutdown w.w_chan false ;
      w.w_state <- `HANGUP
  | Some { state = `EXITING | `EXIT _ }
  | None  ->
      ()


(* Submit [cmd] asynchronously, spawning the subprocess if needed.
   [cb] receives the result; the command id is returned (see abort). *)
let submit ctrl cmd cb =
  let p = ensure_process ctrl in
  _submit p cmd cb

(* Submit [cmd] and run the GLib main loop until its result arrives.
   Returns the output text. An [`ERROR] or [`SYNTAX_ERROR] result is turned
   into an error through [Viz_types.errorf]. *)
let submit_sync ctrl cmd =
  let result = ref None in
  let finished = ref false in
  ignore (submit ctrl cmd (fun v -> result := Some v ; finished := true)) ;
  while not !finished do
    ignore (Glib.Main.iteration true)
  done ;
  match some !result with
  | `OUTPUT msg -> msg
  | `ERROR msg | `SYNTAX_ERROR msg ->
      Viz_types.errorf "mtn automate error: %s" msg



(* Cancel command [nb]. Its callback is dropped, so any output still
   arriving for it is discarded. If the command has not been sent yet, it
   is also removed from the stdin queue. A command whose bytes are partly
   written stays queued, since a partial command cannot be withdrawn. *)
let abort ctrl nb =
  match ctrl.process with
  | None -> ()
  | Some p ->
      p.callbacks <- List.remove_assoc nb p.callbacks ;
      let input = p.p_in in
      begin match input.in_data with
      | [] -> ()
      | (id, _) :: rest when id = nb ->
          if input.in_pos = 0 then begin
            input.in_data <- rest ;
            input.in_pos <- 0
          end
      | head :: rest ->
          input.in_data <- head :: List.remove_assoc nb rest
      end


(* TODO:
   - add a timeout to exit the subprocess in case of inactivity
   - add a submit_delayed to submit a cancellable command
     with a small timeout (for keyboard nav)
   - check exceptions and callbacks
   - add asserts and sanity checks
 *)