(* ------------------------------------------------------------
 * Description
 * ------------------------------------------------------------ *)

(* Boring boilerplate code to map over syntax trees.
 * In the Ivy infrastructure this will be dealt with using dynamic types *)



open Syntax
open Util


(* ------------------------------------------------------------
 * Mapper Functions
 * ------------------------------------------------------------ *)

(* Record of mapping callbacks, one field per kind of syntax-tree node.
 * Every callback except [m_id] takes the mapper record itself as its first
 * argument, so an implementation can reach the callbacks for child nodes
 * through that same record.  [mp_default] and [mp_nothing] below are two
 * ready-made instances of this record. *)
type mp = {
	m_pgrm : mp -> program -> program;
	m_extd : mp -> extdeclaration -> extdeclaration;
	m_cnst : mp -> const -> const;
	m_stmt : mp -> statement -> statement;
	m_expr : mp -> expression -> expression;
	m_blck : mp -> block -> block;
	m_patt : mp -> pattern -> pattern;
	m_jump : mp -> jumpdetails -> jumpdetails;
	m_decl : mp -> declaration -> declaration;
	m_type : mp -> typ -> typ;
	m_btyp : mp -> basetyp -> basetyp;
	m_tini : mp -> tinitialiser -> tinitialiser;
	m_idcl : mp -> init_declarator -> init_declarator;
	m_dtor : mp -> declarator -> declarator;
	m_init : mp -> initialiser -> initialiser;
	m_ctyp : mp -> coretyp -> coretyp;
	m_stdt : mp -> structdetails -> structdetails;
	m_endt : mp -> enumdetails -> enumdetails;
	m_ardt : mp -> argdetails -> argdetails;
	m_dmod : mp -> declmod -> declmod;
	m_idef : mp -> ifacedef -> ifacedef;
	m_ddef : mp -> dictdef -> dictdef;
	m_fdef : mp -> fundef -> fundef;
	m_tyif : mp -> tyiface -> tyiface;
	m_dbod : mp -> dictbody -> dictbody;
	m_qual : mp -> specqual -> specqual;
	m_mknd : mp -> macrokind -> macrokind;
	(* Unlike every other field, [m_id] does not receive the mapper record. *)
	m_id : id -> id;
}

(* Lift the statement callback of [mp] to optional statements by handing it
 * to [option_apply]. *)
let m_stmtopt mp =
	let on_stmt = mp.m_stmt mp in
	option_apply on_stmt
(* Lift the expression callback of [mp] to optional expressions by handing
 * it to [option_apply]. *)
let m_expropt mp =
	let on_expr = mp.m_expr mp in
	option_apply on_expr
(* Lift the identifier callback of [mp] to optional identifiers by handing
 * it to [option_apply]. *)
let m_idopt mp =
	let on_id = mp.m_id in
	option_apply on_id
(* Map the optional expression currently stored in the reference [r] and
 * return the result in a freshly allocated reference; [r] itself is only
 * read, never assigned. *)
let m_exporef mp r =
	let current = !r in
	let mapped = option_apply (mp.m_expr mp) current in
	ref mapped

(* Default mapper for a program: a program is treated as a list, and each
 * element is passed through [mp.m_extd]. *)
let map_pgrm mp prog =
	let on_extd = mp.m_extd mp in
	List.map on_extd prog

(* Default mapper for external declarations.
 * The first tuple component [m] is returned unchanged; the declaration
 * itself is rebuilt with every child node passed through the matching
 * callback of [mp].  [Include] payloads and the [NonDet] key are copied
 * as they are. *)
let map_extd mp (m,extd) =
	let on_decl d = mp.m_decl mp d in
	let mapped = match extd with
		| DSemicolon -> DSemicolon
		| Include s -> Include s
		| Decl d -> Decl (on_decl d)
		| StdType d -> StdType (on_decl d)
		| MacroType k -> MacroType (mp.m_mknd mp k)
		| Func f -> Func (mp.m_fdef mp f)
		| Dict d -> Dict (mp.m_ddef mp d)
		| Interface (name,def) ->
			Interface (mp.m_id name, mp.m_idef mp def)
		| NonDet (key,opts,c) ->
			NonDet (key, List.map (mp.m_extd mp) opts, mp.m_extd mp c)
	in
	m, mapped

(* Default mapper for macro kinds: the wrapped declaration goes through
 * [mp.m_decl], and identifier lists go through [mp.m_id] element by
 * element. *)
let map_mknd mp = function
	| MSilent names -> MSilent (List.map mp.m_id names)
	| MSilentFun names -> MSilentFun (List.map mp.m_id names)
	| MDecl decl -> MDecl (mp.m_decl mp decl)

(* Default mapper for constants.  Only the two size-of forms contain child
 * nodes (a type or an expression); every other constant is returned
 * untouched. *)
let map_cnst mp = function
	| ConstSizeExp e -> ConstSizeExp (mp.m_expr mp e)
	| ConstSizeTy t -> ConstSizeTy (mp.m_type mp t)
	| other -> other

(* Default mapper for statements.
 * The first tuple component [m] is returned unchanged; the statement is
 * rebuilt with each child expression, statement, block, pattern or jump
 * passed through the matching callback of [mp].
 * NOTE(review): the identifier of [Label] is copied as is, whereas
 * [map_jump] applies [mp.m_id] to the target of [JGoto] -- confirm that
 * this asymmetry is intended. *)
let map_stmt mp (m,stmt) =
	let on_expr e = mp.m_expr mp e in
	let on_stmt s = mp.m_stmt mp s in
	let mapped = match stmt with
		| Semicolon -> Semicolon
		| SExp e -> SExp (on_expr e)
		| Block b -> Block (mp.m_blck mp b)
		| Jump j -> Jump (mp.m_jump mp j)
		| Label (id,body) -> Label (id, on_stmt body)
		| Case (pat,body) -> Case (mp.m_patt mp pat, on_stmt body)
		| Switch (e,body) -> Switch (on_expr e, on_stmt body)
		| While (e,body) -> While (on_expr e, on_stmt body)
		| Do (body,e) -> Do (on_stmt body, on_expr e)
		| If (e,st,sf) -> If (on_expr e, on_stmt st, m_stmtopt mp sf)
		| For (init,cond,inc,body) ->
			For (m_expropt mp init, m_expropt mp cond, m_expropt mp inc,
				on_stmt body)
	in
	m, mapped
	
(* Default mapper for expressions.
 * The first tuple component [m] is returned unchanged; the expression is
 * rebuilt with every child node passed through the matching callback of
 * [mp].  The first component of [Field], the operator of [Assign] and the
 * key of [JklNonDet] are copied as they are. *)
let map_expr mp (m,expr) =
	let on_expr e = mp.m_expr mp e in
	let on_blck b = mp.m_blck mp b in
	let mapped = match expr with
		| Var id -> Var (mp.m_id id)
		| Const c -> Const (mp.m_cnst mp c)
		| Init ti -> Init (mp.m_tini mp ti)
		| Parens e -> Parens (on_expr e)
		| Unsafe e -> Unsafe (on_expr e)
		| EBlock b -> EBlock (on_blck b)
		| Cast (t,e) -> Cast (mp.m_type mp t, on_expr e)
		| Field (b,e,id) -> Field (b, on_expr e, mp.m_id id)
		| Index (ea,ei) -> Index (on_expr ea, on_expr ei)
		| Assign (el,op,er) -> Assign (on_expr el, op, on_expr er)
		| Choice (ec,et,ef) -> Choice (on_expr ec, on_expr et, on_expr ef)
		| FunCall (f,args) ->
			FunCall (on_expr f, List.map (mp.m_expr mp) args)
		| LocalFun (a,body) ->
			LocalFun (List.map (mp.m_decl mp) a, on_blck body)
		| JklNonDet (k,opts,c) ->
			JklNonDet (k, List.map (mp.m_expr mp) opts, on_expr c)
	in
	m, mapped
	
(* Default mapper for blocks: a block is a pair of a declaration list and a
 * statement list; each element goes through [mp.m_decl] or [mp.m_stmt]
 * respectively. *)
let map_blck mp (decls,stmts) =
	( List.map (mp.m_decl mp) decls,
	  List.map (mp.m_stmt mp) stmts )

(* Default mapper for patterns: identifiers of [PTag] go through [mp.m_id]
 * (the second one being optional), the expression of [PConst] goes through
 * [mp.m_expr], and [PDefault] has nothing to map. *)
let map_patt mp = function
	| PDefault -> PDefault
	| PConst e -> PConst (mp.m_expr mp e)
	| PTag (tag,arg) -> PTag (mp.m_id tag, m_idopt mp arg)
	
(* Default mapper for jump details: the optional expression of [JReturn]
 * and [JRet] goes through [mp.m_expr], the target of [JGoto] goes through
 * [mp.m_id], and [JBreak] / [JContinue] are returned as they are. *)
let map_jump mp = function
	| (JBreak | JContinue) as leaf -> leaf
	| JGoto target -> JGoto (mp.m_id target)
	| JReturn value -> JReturn (m_expropt mp value)
	| JRet value -> JRet (m_expropt mp value)

(* Default mapper for base types: a base type is a pair of a qualifier list
 * and a core type; qualifiers go through [mp.m_qual] one by one and the
 * core type goes through [mp.m_ctyp]. *)
let map_btyp mp (qualifiers,core) =
	( List.map (mp.m_qual mp) qualifiers,
	  mp.m_ctyp mp core )

(* Default mapper for declarations.  The first tuple component [m] is
 * returned unchanged; the base type goes through [mp.m_btyp] and each
 * init-declarator through [mp.m_idcl]. *)
let map_decl mp (m,(base,declarators)) =
	( m,
	  ( mp.m_btyp mp base,
	    List.map (mp.m_idcl mp) declarators ) )
			
(* Default mapper for full types (qualifiers, core type, declarator
 * modifiers).  The qualifier/core-type pair is delegated to [mp.m_btyp]
 * first, then each modifier goes through [mp.m_dmod]. *)
let map_type mp (quals,coret,declmods) =
	let (quals',coret') = mp.m_btyp mp (quals,coret) in
	let declmods' = List.map (mp.m_dmod mp) declmods in
	(quals',coret',declmods')
		
(* Default mapper for typed initialisers.  The first tuple component [m] is
 * returned unchanged.  Expressions go through [mp.m_expr] and the
 * identifier of [TConApp] through [mp.m_id]; the first component of each
 * [TStruct] field and the kind of [TAlloc] are copied as they are. *)
let map_tini mp (m,tinit) =
	let on_expr e = mp.m_expr mp e in
	let mapped = match tinit with
		| TAlloc (kind,e) -> TAlloc (kind, on_expr e)
		| TConApp (id,arg) -> TConApp (mp.m_id id, m_expropt mp arg)
		| TStruct flds ->
			let on_field (s,e) = (s, on_expr e) in
			TStruct (List.map on_field flds)
	in
	m, mapped
	
(* Default mapper for init-declarators: the declarator goes through
 * [mp.m_dtor] and the optional initialiser through [mp.m_init]. *)
let map_idcl mp (declarator,init_opt) =
	( mp.m_dtor mp declarator,
	  option_apply (mp.m_init mp) init_opt )

(* Default mapper for declarators: each modifier goes through [mp.m_dmod]
 * and the optional identifier through [mp.m_id]. *)
let map_dtor mp (modifiers,name_opt) =
	( List.map (mp.m_dmod mp) modifiers,
	  m_idopt mp name_opt )

(* Default mapper for initialisers: an [IExp] expression goes through
 * [mp.m_expr]; the nested initialisers of [IFields] go through
 * [mp.m_init], with the second component copied as it is. *)
let map_init mp = function
	| IFields (members,b) -> IFields (List.map (mp.m_init mp) members, b)
	| IExp e -> IExp (mp.m_expr mp e)
	
(* Default mapper for core types.  Identifiers go through [mp.m_id],
 * optional struct/enum details through [mp.m_stdt] / [mp.m_endt], and the
 * operand of the two typeof forms through [mp.m_expr] / [mp.m_type].
 * [TyBasic] and [TyVoid] are returned as they are; the struct kind and the
 * first component of the typeof forms are copied unchanged. *)
let map_ctyp mp = function
	| (TyBasic | TyVoid) as leaf -> leaf
	| TyName id -> TyName (mp.m_id id)
	| TyWild id -> TyWild (mp.m_id id)
	| TyTypeofExp (b,e) -> TyTypeofExp (b, mp.m_expr mp e)
	| TyTypeofTyp (b,t) -> TyTypeofTyp (b, mp.m_type mp t)
	| TyEnum (name,details) ->
		TyEnum (m_idopt mp name, option_apply (mp.m_endt mp) details)
	| TyStruct (kind,name,details) ->
		TyStruct (kind, m_idopt mp name,
			option_apply (mp.m_stdt mp) details)
		
(* Default mapper for struct details.  The first tuple component [m] is
 * returned unchanged; the identifier list goes through [mp.m_id] and the
 * declaration list through [mp.m_decl]. *)
let map_stdt mp (m,tyvars,members) =
	( m,
	  List.map mp.m_id tyvars,
	  List.map (mp.m_decl mp) members )

(* Default mapper for enum details: a list of (identifier, optional
 * expression) pairs; identifiers go through [mp.m_id] and the optional
 * expressions through [mp.m_expr]. *)
let map_endt mp entries =
	let on_entry (name,value) = (mp.m_id name, m_expropt mp value) in
	List.map on_entry entries
	
(* Default mapper for declarator modifiers.  Qualifier lists go through
 * [mp.m_qual], expressions through [mp.m_expr], argument details through
 * [mp.m_ardt] and type lists through [mp.m_type].  The kind of [DFun] is
 * copied as it is and [DParens] has nothing to map. *)
let map_dmod mp = function
	| DParens -> DParens
	| DPtr quals -> DPtr (List.map (mp.m_qual mp) quals)
	| DFatPtr quals -> DFatPtr (List.map (mp.m_qual mp) quals)
	| DArray size -> DArray (m_expropt mp size)
	| DBitField e -> DBitField (mp.m_expr mp e)
	| DWithArgs typs -> DWithArgs (List.map (mp.m_type mp) typs)
	| DFun (details,kind) -> DFun (mp.m_ardt mp details, kind)

(* Default mapper for argument details.  For [ArgsFull] the first
 * component is copied unchanged, the second list goes through [mp.m_tyif]
 * and the declarations through [mp.m_decl]; the names of [ArgsNamed] go
 * through [mp.m_id]; [ArgsNoinfo] has nothing to map. *)
let map_ardt mp = function
	| ArgsNoinfo -> ArgsNoinfo
	| ArgsNamed names -> ArgsNamed (List.map mp.m_id names)
	| ArgsFull (m,ifaces,params) ->
		ArgsFull (m, List.map (mp.m_tyif mp) ifaces,
			List.map (mp.m_decl mp) params)
	
(* Default mapper for interface definitions: the identifier goes through
 * [mp.m_id] and each element of the accompanying list through
 * [mp.m_decl]. *)
let map_idef mp (name,signatures) =
	( mp.m_id name,
	  List.map (mp.m_decl mp) signatures )

(* Default mapper for dictionary definitions: the type goes through
 * [mp.m_type], the interface identifier through [mp.m_id] and the body
 * through [mp.m_dbod]. *)
let map_ddef mp (ty,iface_name,dict_body) =
	( mp.m_type mp ty,
	  mp.m_id iface_name,
	  mp.m_dbod mp dict_body )
	
(* Default mapper for function definitions: the leading declaration and
 * every declaration of the second component go through [mp.m_decl], and
 * the body goes through [mp.m_blck]. *)
let map_fdef mp (header,kr_decls,body) =
	( mp.m_decl mp header,
	  List.map (mp.m_decl mp) kr_decls,
	  mp.m_blck mp body )
	
(* Default mapper for type/interface pairs: both components are identifiers
 * and both go through [mp.m_id]. *)
let map_tyif mp (ty_name,iface_name) =
	( mp.m_id ty_name,
	  mp.m_id iface_name )

(* Default mapper for dictionary bodies.  The type/interface pairs go
 * through [mp.m_tyif]; for [DictImpl], each function definition goes
 * through [mp.m_fdef] while the value paired with it is copied as it
 * is. *)
let map_dbod mp = function
	| DictProto pairs -> DictProto (List.map (mp.m_tyif mp) pairs)
	| DictImpl (pairs,impls) ->
		let on_impl (m,fdef) = (m, mp.m_fdef mp fdef) in
		DictImpl (List.map (mp.m_tyif mp) pairs, List.map on_impl impls)

(* Callback that leaves its node untouched: the mapper argument is ignored
 * and the second argument is returned as is. *)
let nothing _mp node = node

(* Mapper whose callbacks are the [map_*] functions of this file.  The two
 * exceptions are identifiers, which use [identity], and qualifiers, which
 * use [nothing]. *)
let mp_default = {
	m_pgrm = map_pgrm;
	m_extd = map_extd;
	m_cnst = map_cnst;
	m_stmt = map_stmt;
	m_expr = map_expr;
	m_blck = map_blck;
	m_patt = map_patt;
	m_jump = map_jump;
	m_decl = map_decl;
	m_type = map_type;
	m_btyp = map_btyp;
	m_tini = map_tini;
	m_idcl = map_idcl;
	m_dtor = map_dtor;
	m_init = map_init;
	m_ctyp = map_ctyp;
	m_stdt = map_stdt;
	m_endt = map_endt;
	m_ardt = map_ardt;
	m_dmod = map_dmod;
	m_idef = map_idef;
	m_ddef = map_ddef;
	m_fdef = map_fdef;
	m_tyif = map_tyif;
	m_dbod = map_dbod;
	m_qual = nothing;
	m_mknd = map_mknd;
	m_id = identity;
}

(* Mapper in which every callback returns its node unchanged: all
 * node callbacks are [nothing] and the identifier callback is
 * [identity]. *)
let mp_nothing = {
	m_pgrm = nothing;
	m_extd = nothing;
	m_cnst = nothing;
	m_stmt = nothing;
	m_expr = nothing;
	m_blck = nothing;
	m_patt = nothing;
	m_jump = nothing;
	m_decl = nothing;
	m_type = nothing;
	m_btyp = nothing;
	m_tini = nothing;
	m_idcl = nothing;
	m_dtor = nothing;
	m_init = nothing;
	m_ctyp = nothing;
	m_stdt = nothing;
	m_endt = nothing;
	m_ardt = nothing;
	m_dmod = nothing;
	m_idef = nothing;
	m_ddef = nothing;
	m_fdef = nothing;
	m_tyif = nothing;
	m_dbod = nothing;
	m_qual = nothing;
	m_mknd = nothing;
	m_id = identity;
}

(* Entry point: run the program callback of [mp] on [p], handing it [mp]
 * itself as the mapper to use. *)
let map_program mp p =
	let on_program = mp.m_pgrm in
	on_program mp p


(* [fseq funcs x] threads [x] through every function of [funcs], from the
 * head of the list to its end: [fseq [f; g] x] is [g (f x)], and
 * [fseq [] x] is [x].
 * Implemented as a left fold instead of a hand-written recursion; the
 * order of application is the same. *)
let fseq funcs x =
	List.fold_left (fun acc f -> f acc) x funcs
	
(* Build a callback that first threads the node [x] through [funcs] (in
 * list order, via [fseq]) and then hands the result to [dflt] together
 * with the mapper [mp]. *)
let seq funcs dflt mp x =
	let prepared = fseq funcs x in
	dflt mp prepared
(* Like [seq], with a post-processing stage: the node goes through
 * [before], then through [dflt mp], and the result is finally threaded
 * through the functions of [after]. *)
let wseq before dflt after mp x =
	let mapped = seq before dflt mp x in
	fseq after mapped
	
