open Util
open Metadata
open Syntax
open Location
open Syntaxutil
open Prettyutil
open Prettyprint
open Errors
open Names


(* ------------------------------------------------------------
 * Utility bits
 * ------------------------------------------------------------ *)

(* String-keyed collections from the project's Data library.
 * SH is a hash table keyed by strings and SS a set of strings; both are
 * used below via mem/find/add. NOTE(review): exact add/replace semantics
 * of SH.add are not visible here -- confirm against Data.StringCols. *)
module SH = Data.StringCols.Hash
module SS = Data.StringCols.ImpSet
(* hash table from names (strings) to 'a *)
type 'a strhash = 'a SH.t
(* set of names (strings) *)
type strset = SS.t


(* ------------------------------------------------------------
 * Type Environment
 * ------------------------------------------------------------ *)

(* A scope chain, innermost frame first. Each frame pairs the bindings
 * made in that scope with an optional table that is present only for
 * lambda frames; find_id_in_scope records captured outer bindings there. *)
type 'a scope = ('a strhash * 'a strhash option) list

(* push a fresh ordinary (non-lambda) frame onto [parent] *)
let newscope parent =
	let frame = (SH.create (), None) in
	frame :: parent

(* a chain holding a single empty ordinary frame *)
let rootscope () = newscope []

(* push a fresh lambda frame whose captured-variable table is [set] *)
let lambdascope set parent =
	let frame = (SH.create (), Some set) in
	frame :: parent


(* Definitions in scope. The hash tables are mutable and shared by every
 * env derived from the one that created them; only the two scope chains
 * (typscopes, declscopes) get new frames per nested scope (see localenv). *)
type definfo = {
	(* variable types, innermost frame first (find_id_type, add_localvar);
	 * the outermost frame is treated as the globals by is_global *)
	typscopes : typ scope;
	(* declarations, scoped like typscopes (find_id_decl, add_localdecl) *)
	declscopes : declaration scope;
	(* interfaces by name (find_iface, add_ifacedef) *)
	ifacedefs : ifacedef strhash;
	(* dictionary prototypes keyed by type name, several per name;
	 * find_dict expects at most one per interface *)
	dictdecls : dictdecl list strhash;
	(* struct / tagged / union definitions, selected by type_scope *)
	structdefs : structdetails strhash;
	taggeddefs : structdetails strhash;
	uniondefs : structdetails strhash;
	enumdefs : enumdetails strhash;
	(* typedef name -> the type it stands for (expanded by resolve_typ) *)
	typedefs : typ strhash;
	(* constructor name -> (id list, typ, id); NOTE(review): the meaning of
	 * each component is not visible in this chunk *)
	constrdefs : (id list * typ * id) strhash;
	(* method name -> id; find_method_iface suggests this is the
	 * interface declaring the method -- confirm *)
	methoddefs : id strhash;
	(* filled by add_stdtype; not read in this chunk *)
	stdtypes : typ strhash;
	}
	
(* Information about what is currently being checked. Copying an env with
 * {... with ...} shares every ref field that is not explicitly replaced,
 * so updates through a ref are seen by all envs that share it. *)
type context = {
	(* return type of the enclosing function (set_rettype / get_rettype) *)
	rettype : typ option ref;
	(* set_conttype / get_conttype; NOTE(review): meaning not visible here *)
	conttype : typ option ref;
	(* the switch scrutinee and its type; None outside a switch *)
	scrutinee : (expression * typ) option;
	(* one frame per localenv: type parameter name -> interface specs *)
	typarams : ifacespec list strhash list;
	(* names registered with add_dict_tyvar, queried by is_dict_tyvar *)
	dicttyvars : strset;
	(* new_dictenv yields DictThis when asked for this environment *)
	thisdictenv : dictenv;
	(* set by withunsafe, read by get_unsafe_allowed *)
	unsafe : bool;
	(* type the current expression is expected to have, if known *)
	expect_typ : typ option;
	(* temporaries of the current statement (withtemps resets) *)
	temps : tinitialiser list ref;
	(* per-block lists, reset together by withfwddecls *)
	fwddecls : (namevar * typ) list ref;
	tempdecls : (namevar * typ) list ref;
	lamenvs : localfuninfo list ref;
	(* lambdas of the current function (withlambdas resets) *)
	lambdas : lamdef list ref;
	dictenvs : (string * id * id * dictenv) list ref;
	(* while true, new_temp creates nothing; withmeta clears it *)
	in_tinit : bool;
	(* set_repchanged / get_repchanged; with_repcheck installs a fresh one *)
	repchanged : bool ref;
	(* thunk whose output prefixes error_mismatch messages *)
	tycmp_context : unit -> pretty;
	(* statement / declaration echoed by tyerr (outside j2c mode) *)
	curstmt : statement option;
	curdecl : declaration option;
	}
	
(* Counters from which fresh names are generated. They are refs so that all
 * envs derived from one share and advance the same counter; withfwddecls
 * and withlambdas install fresh counters starting at 0. *)
type naminginfo = {
	(* next number for new_dictenv_name *)
	nextdictenvnum : int ref;
	(* NOTE(review): not read or reset anywhere in this chunk *)
	nextfunenvnum : int ref;
	(* next number for new_temp_name *)
	nexttempnum : int ref;
	(* next number appended to basename in new_lambda *)
	nextlambdanum : int ref;
	(* prefix for lambda names; set by withlambdas *)
	basename : string;
}

(* The complete type-checking environment threaded through the checker. *)
type env = 
	{defs : definfo;
	 context : context;
	 naminginfo : naminginfo;
	 (* metadata of the node currently being checked (see withmeta) *)
	 meta : metadata 
	 }


(* ------------------------------------------------------------
 * Create an empty type environment
 * ------------------------------------------------------------ *)

(* A definitions record with one empty frame per scope chain and an empty
 * table for every other kind of definition. *)
let new_defs () =
	let fresh () = SH.create () in
	{
		typscopes = rootscope ();
		declscopes = rootscope ();
		ifacedefs = fresh ();
		dictdecls = fresh ();
		structdefs = fresh ();
		taggeddefs = fresh ();
		uniondefs = fresh ();
		enumdefs = fresh ();
		typedefs = fresh ();
		constrdefs = fresh ();
		methoddefs = fresh ();
		stdtypes = fresh ();
	}

(* The context outside any function, statement or switch: no expected
 * types, empty accumulators and a tycmp_context that adds nothing. *)
let new_context () =
	let fresh_list () = ref [] in
	{
		rettype = ref None;
		conttype = ref None;
		scrutinee = None;
		typarams = [];
		dicttyvars = SS.create ();
		thisdictenv = [];
		unsafe = false;
		expect_typ = None;
		temps = fresh_list ();
		fwddecls = fresh_list ();
		tempdecls = fresh_list ();
		lamenvs = fresh_list ();
		lambdas = fresh_list ();
		dictenvs = fresh_list ();
		in_tinit = false;
		repchanged = ref false;
		tycmp_context = (fun () -> empty);
		curstmt = None;
		curdecl = None;
	}

(* All name counters start at zero; lambdas get no base name yet. *)
let new_naminginfo () =
	let counter () = ref 0 in
	{
		nextdictenvnum = counter ();
		nextfunenvnum = counter ();
		nexttempnum = counter ();
		nextlambdanum = counter ();
		basename = "";
	}

(* A completely fresh environment with no metadata attached. *)
let new_env () =
	let defs = new_defs () in
	let context = new_context () in
	let naminginfo = new_naminginfo () in
	let meta = nometa () in
	{defs = defs; context = context; naminginfo = naminginfo; meta = meta}
			

(* ------------------------------------------------------------
 * Error Handling
 * ------------------------------------------------------------ *)

(* Report [pretty] as a type error at env.meta via pfatal, inside
 * with_suspend_errors. Unless in j2c mode, a paragraph showing how the
 * current statement (preferred) or declaration was decoded is appended.
 * The result is used at arbitrary types below, so this never returns
 * normally. *)
let tyerr env pretty = with_suspend_errors (fun () ->
	let decoded title body =
		newline <+> str "" <+> newline <+> str title <++> body in
	let detail =
		if !Cmdargs.j2c then empty
		else match env.context.curstmt, env.context.curdecl with
			| Some s, _ -> decoded "Statement decoded as:" (pprint_stmt s)
			| None, Some d -> decoded "Declaration decoded as:" (pprint_decl d)
			| None, None -> empty in
	pfatal env.meta (pretty <+> detail))

(* type error of the form "Type <ty> <msg>" *)
let tywrong env ty msg =
	let subject = str "Type" <++> pp_ty ty in
	tyerr env (subject <++> msg)

(* warning at the current node's metadata, via pwarning *)
let tywarn env pretty =
	let m = env.meta in
	pwarning m pretty

let error_arg_mismatch env =
	tyerr env (str "Function Argument Mismatch")

(* mismatch error prefixed by the installed tycmp_context description *)
let error_mismatch env reason =
	let what = env.context.tycmp_context () in
	tyerr env (what <+> newline <+> reason)

(* error reported at the location carried in the first component of [id] *)
let idtyerr (m, _) pretty = pfatal m pretty


(* ------------------------------------------------------------
 * Pretty print errors
 * ------------------------------------------------------------ *)

(* Message for a mismatch between the type found ([tyfound]) and the type
 * that [expector] requires ([tywant]). The trailing unit argument makes a
 * partial application a unit -> pretty thunk, the shape of the
 * tycmp_context field (see with_unify). *)
let pprint_type_mismatch expector tyfound tywant () =
	str "type mismatch" <+> newline <+>
	str "type is:" <++> pp_ty tyfound <+> newline <+>
	(* <++> (not <+>) so the expected type is separated from "expects:"
	 * the same way the found type is separated from "type is:" *)
	str "but"<++>expector<++>str "expects:" <++>
		pp_ty tywant

(* Message for a method implementation whose signature differs from the
 * interface's; trailing unit makes partial application a thunk. *)
let pprint_funsig_mismatch sigfound sigwant () =
	let found = pprint_funsig "" sigfound in
	let required = pprint_funsig "" sigwant in
	str "method implementation does not match interface" <+> newline
	<+> str "found:" <++> found <+> newline
	<+> str "required:" <++> required


(* ------------------------------------------------------------
 * Nested Environments
 * ------------------------------------------------------------ *)

(* Enter a nested scope: push a fresh frame on both scope chains and a fresh
 * type-parameter table; all other definitions are shared with [env]. *)
let localenv env =
	let defs = env.defs and ctx = env.context in
	{env with
		defs = {defs with
			typscopes = newscope defs.typscopes;
			declscopes = newscope defs.declscopes};
		context = {ctx with typarams = SH.create () :: ctx.typarams}}

(* Enter a lambda body: the variable chain gets a lambda frame that records
 * captured outer variables into [varset]; declarations get a plain frame. *)
let withlambdavars env varset =
	let defs = env.defs in
	let inner = {defs with
		declscopes = newscope defs.declscopes;
		typscopes = lambdascope varset defs.typscopes} in
	{env with defs = inner}

(* Move to the node described by [meta]; this also leaves any
 * temp-initialiser context, so new_temp creates temporaries again. *)
let withmeta env meta =
	let ctx = {env.context with in_tinit = false} in
	{env with meta = meta; context = ctx}

(* check statement [stmt], remembering it for error messages *)
let withstmt env ((meta, _) as stmt) =
	let ctx = {env.context with curstmt = Some stmt} in
	withmeta {env with context = ctx} meta

(* check declaration [decl], remembering it for error messages *)
let withdecl env ((meta, _) as decl) =
	let ctx = {env.context with curdecl = Some decl} in
	withmeta {env with context = ctx} meta
		
(* Context updates. Each returns a functional copy of [env]; ref cells that
 * are not explicitly replaced remain shared with [env]. *)

let withdictenv env dictenv =
	let ctx = env.context in
	{env with context = {ctx with thisdictenv = dictenv}}

let withunsafe env =
	let ctx = env.context in
	{env with context = {ctx with unsafe = true}}

let with_expect_typ env typ =
	let ctx = env.context in
	{env with context = {ctx with expect_typ = Some typ}}

let with_unknown_typ env =
	let ctx = env.context in
	{env with context = {ctx with expect_typ = None}}

(* fresh per-statement temporary list *)
let withtemps env =
	let ctx = env.context in
	{env with context = {ctx with temps = ref []}}

(* fresh per-block lists, with temp and dictenv numbering restarted at 0 *)
let withfwddecls env =
	let ctx = env.context and names = env.naminginfo in
	{env with
		context = {ctx with
			fwddecls = ref []; tempdecls = ref [];
			lamenvs = ref []; dictenvs = ref []};
		naminginfo = {names with
			nexttempnum = ref 0; nextdictenvnum = ref 0}}

(* fresh lambda list for function [id]; lambdas are named from [id] and
 * numbered from 0 (see new_lambda) *)
let withlambdas env id =
	let ctx = env.context and names = env.naminginfo in
	{env with
		context = {ctx with lambdas = ref []};
		naminginfo = {names with
			nextlambdanum = ref 0; basename = id_str id}}

let with_tinit env =
	let ctx = env.context in
	{env with context = {ctx with in_tinit = true}}

(* fresh repchanged flag, independent of the parent's *)
let with_repcheck env =
	let ctx = env.context in
	{env with context = {ctx with repchanged = ref false}}

(* install the thunk error_mismatch uses to describe the comparison *)
let with_unify env whatcheck =
	let ctx = env.context in
	{env with context = {ctx with tycmp_context = whatcheck}}

let with_scrutinee env eo =
	let ctx = env.context in
	{env with context = {ctx with scrutinee = eo}}


(* ------------------------------------------------------------
 * Imperative update
 * ------------------------------------------------------------ *)

(* Writes through the shared ref cells of the context. *)
let set_repchanged env =
	let ctx = env.context in
	ctx.repchanged := true
let set_rettype env t =
	let ctx = env.context in
	ctx.rettype := t
let set_conttype env t =
	let ctx = env.context in
	ctx.conttype := t


(* ------------------------------------------------------------
 * Get the current state
 * ------------------------------------------------------------ *)

(* Plain fields *)
let get_meta env = env.meta
let get_expectty env = env.context.expect_typ
let get_unsafe_allowed env = env.context.unsafe

(* Current contents of ref-valued fields *)
let get_repchanged env = let c = env.context in !(c.repchanged)
let get_funlambdas env = let c = env.context in !(c.lambdas)
let get_conttype env = let c = env.context in !(c.conttype)
let get_rettype env = let c = env.context in !(c.rettype)
let get_block_fwds env = let c = env.context in !(c.fwddecls)
let get_block_temps env = let c = env.context in !(c.tempdecls)
let get_block_lamenvs env = let c = env.context in !(c.lamenvs)
let get_block_dictenvs env = let c = env.context in !(c.dictenvs)
let get_stmt_temps env = let c = env.context in !(c.temps)

(* The current switch scrutinee and its type; a type error outside a
 * switch. *)
let get_scrutinee env =
	match env.context.scrutinee with
	| None -> tyerr env (str "Not inside a switch statement")
	| Some scrut -> scrut

(* The expected type; a type error when it is not known. *)
let get_expect_typ env =
	match env.context.expect_typ with
	| None -> tyerr env (str "Type is ambiguous - maybe use a cast")
	| Some typ -> typ


(* ------------------------------------------------------------
 * Adding Things
 * ------------------------------------------------------------ *)

(* Bind [id] to [data], reporting an error at [id] if already bound. *)
let hashadd id data hash =
	let key = id_str id in
	if not (SH.mem key hash) then SH.add key data hash
	else idtyerr id (str "Duplicate name : " <+> pp_id id)

(* Prepend [data] to the list bound to [id] (empty if unbound). *)
let hashlist_add id data hash =
	let key = id_str id in
	let existing = SH.find_or_create key (fun () -> []) hash in
	SH.add key (data :: existing) hash

(* The definition table for the given kind of struct-like type. *)
let type_scope env = function
	| SKTagged -> env.defs.taggeddefs
	| SKUnion -> env.defs.uniondefs
	| SKStruct -> env.defs.structdefs
	
(* [idadd hash id data] binds [id]'s name to [data] with SH.add; unlike
 * hashadd there is no duplicate check. *)
let idadd hash id data = SH.add (id_str id) data hash
(* [idhadd hash id data] prepends [data] to the list bound to [id]. *)
let idhadd hash id data = hashlist_add id data hash
	
(* Registration helpers. Each is partially applied: the result still takes
 * the [id] and [data] arguments of idadd / idhadd. *)
let add_structdef env kind = idadd (type_scope env kind)
let add_enumdef env = idadd env.defs.enumdefs
let add_constrdef env = idadd env.defs.constrdefs
let add_typedef env = idadd env.defs.typedefs
(* innermost declaration frame *)
let add_localdecl env = idadd (fst (List.hd env.defs.declscopes))
(* appends an interface spec for a type parameter in the innermost
 * typarams frame *)
let add_localty env = idhadd (List.hd env.context.typarams)
(* innermost variable-type frame *)
let add_localvar env = idadd (fst (List.hd env.defs.typscopes))
let add_dictproto env = idhadd env.defs.dictdecls
let add_methoddef env = idadd env.defs.methoddefs
let add_ifacedef env = idadd env.defs.ifacedefs
let add_stdtype env = idadd env.defs.stdtypes
(* records [id]'s name in dicttyvars (queried by is_dict_tyvar) *)
let add_dict_tyvar env id = SS.add (id_str id) env.context.dicttyvars


(* ------------------------------------------------------------
 * Find things
 * ------------------------------------------------------------ *)

(* Look [id] up in [hash]; if absent, report "No such <kind> : <id>" at
 * [id]. [env] is unused but kept for a uniform call shape. *)
let try_find env hash id kind =
	let key = id_str id in
	if SH.mem key hash then SH.find key hash
	else idtyerr id (str "No such " <+> str kind <+> str " : " <+> pp_id id)

(* Look [id] up in a scope chain, innermost frame first, and return the
 * bound value.
 * Passing a lambda frame (Some lambdavars) without finding [id] in its own
 * bindings means [id] is bound outside the lambda. Once the outer lookup
 * succeeds, env.meta is flagged with set_is_envvar and the binding is
 * copied into lambdavars, recording a captured variable. The recursion does
 * this for every lambda frame crossed, so each nested lambda records it.
 * If no frame binds [id], a type error is reported. The message is
 * macro-specific when Parseutils.is_define recognises the name.
 * NOTE(review): a binding found in the outermost (global) frame is also
 * recorded as captured -- confirm consumers filter globals (cf. is_global).
 * SH.add runs on every such lookup; if it adds rather than replaces,
 * repeated lookups would leave duplicate entries in lambdavars. *)
let rec find_id_in_scope env scope id = match scope with 
	| (vars,_)::_ when SH.mem (id_str id) vars ->
			SH.find (id_str id) vars
	| (_,Some lambdavars)::xs ->
			(* found further out: record it as captured by this lambda *)
			let res = find_id_in_scope env xs id in
			set_is_envvar env.meta;
			SH.add (id_str id) res lambdavars;
			res
	| (_,None)::xs -> find_id_in_scope env xs id
	| [] when Parseutils.is_define (id_str id) -> tyerr env (
			str "No type specified for macro:"<++> pp_id id)
	| [] -> tyerr env (str "No such variable in scope:"<++>pp_id id)		

(* Named definitions; try_find reports "No such <kind> : <id>" when the
 * name is not defined. *)
let find_struct env id = try_find env env.defs.structdefs id "struct"
let find_tagged env id = try_find env env.defs.taggeddefs id "tagged"
let find_iface env id = try_find env env.defs.ifacedefs id "interface"
let find_typedef env id = try_find env env.defs.typedefs id "typedef"
let find_constr env id = try_find env env.defs.constrdefs id "constructor"
let find_method_iface env id = try_find env env.defs.methoddefs id "method"

(* Scoped names, resolved through the scope chains. *)
let find_id_type env id = find_id_in_scope env env.defs.typscopes id
let find_id_decl env id = find_id_in_scope env env.defs.declscopes id

(* Interface specs recorded for type parameter [id], concatenated over
 * every typarams frame; a frame without an entry contributes []. *)
let find_typaram_ifaces env id =
	let name = id_str id in
	let specs_in frame = SH.find_or_create name (fun () -> []) frame in
	concat_map specs_in env.context.typarams

(* All dictionary prototypes declared for type [id] (possibly []). *)
let find_dicts env id =
	let none () = [] in
	SH.find_or_create (id_str id) none env.defs.dictdecls

(* The unique prototype for type [tyid] implementing interface [ifaceid];
 * a type error if there is none or more than one. *)
let find_dict env tyid ifaceid =
	let for_iface (_, iface, _) = id_equal iface ifaceid in
	let candidates = List.filter for_iface (find_dicts env tyid) in
	match candidates with
	| [d] -> d
	| [] -> tyerr env (str "No implementation prototype for" <++> 
					pp_id tyid <++> str ":" <++> pp_id ifaceid)
	| _ -> tyerr env (str "Multiple prototypes for" <++> 
					pp_id tyid <++> str ":" <++> pp_id ifaceid)

let is_dict_tyvar env id =
	let name = id_str id in
	SS.mem name env.context.dicttyvars

(* bound in the outermost variable frame, or a builtin operator *)
let is_global env id =
	let name = id_str id in
	let globals, _ = list_last env.defs.typscopes in
	SH.mem name globals || member name builtin_ops

let is_method env id =
	let name = id_str id in
	SH.mem name env.defs.methoddefs
					

(* ------------------------------------------------------------
 * Create Temporaries
 * ------------------------------------------------------------ *)

(* Return the current value of [counter] as a string and advance it by one. *)
let next_counter_string counter =
	let current = !counter in
	counter := current + 1;
	string_of_int current

let next_dictenv_num env = next_counter_string env.naminginfo.nextdictenvnum
let next_temp_num env = next_counter_string env.naminginfo.nexttempnum
let next_lambda_num env = next_counter_string env.naminginfo.nextlambdanum

let new_temp_name env =
	let num = next_temp_num env in
	name_temp num

let new_dictenv_name env =
	let num = next_dictenv_num env in
	name_dicttemp num

(* Dictionary reference for [envargs]: none needed when empty, DictThis for
 * the current dictionary environment, else a newly named one recorded in
 * the block's dictenvs. *)
let new_dictenv env tyid ifaceid envargs =
	match envargs with
	| [] -> DictGlobal None
	| _ when dictenv_equal envargs env.context.thisdictenv -> DictThis
	| _ ->
		let name = new_dictenv_name env in
		let entry = (name, tyid, ifaceid, envargs) in
		env.context.dictenvs := entry :: !(env.context.dictenvs);
		DictGlobal (Some name)

(* Record a forward declaration of [id] with type [typ] for this block. *)
let new_fwddecl env id typ =
	let decl = (new_fixed_namevar (id_str id), typ) in
	env.context.fwddecls := decl :: !(env.context.fwddecls)

let new_temp_namevar env = new_namevar (new_temp_name env) ""
		
(* Allocate a temporary of type [typ] for node [m] and queue its initialiser.
 * Does nothing while checking a temp initialiser (in_tinit). *)
let new_temp env typ (m, tinit) =
	if env.context.in_tinit then () else begin
		let ctx = env.context in
		let namevar = new_temp_namevar env in
		ctx.tempdecls := (namevar, typ) :: !(ctx.tempdecls);
		ctx.temps := (m, tinit) :: !(ctx.temps);
		set_tempname m namevar;
		set_sharedpretty m (ref None)
	end

(* Register a lambda defined at the current node (env.meta).
 * Names come from the function basename plus the next lambda number:
 * "_ff_<base>" (funname), "_fe<base>" (strname) and, only when [envvars]
 * is non-empty, "_ft_<base>" (envname); otherwise envname is the fixed
 * name "_DNULL".
 * The resulting funinfo is pushed onto the block's lamenvs and attached to
 * the node with [rettyp]; (meta, args, body) goes onto the function's
 * lambda list.
 * NOTE(review): "_fe" lacks the trailing underscore of "_ff_"/"_ft_".
 * Confirm whether that is intentional; changing it alters generated names. *)
let new_lambda env rettyp args body envvars =
	let num = next_lambda_num env in
	let basename = env.naminginfo.basename ^ num in
	let funname = new_namevar ("_ff_" ^ basename) "_ff" in
	let strname = new_namevar ("_fe" ^ basename) "_fe" in
	let m = env.meta in
	(* the captured variables recorded by find_id_in_scope *)
	let vars = SH.to_list envvars in
	let envname = if vars <> [] then 
			new_namevar ("_ft_" ^ basename) "_ft"
		else new_fixed_namevar "_DNULL" in
	let funinfo = {funname = funname; envname = envname; 
						strname = strname; envvars = vars} in
	env.context.lamenvs := funinfo::!(env.context.lamenvs);
	set_localfun_info m funinfo;
	set_return_type m rettyp; 
	env.context.lambdas := (m,args,body)::!(env.context.lambdas);
	set_sharedpretty m (ref None)



(* ------------------------------------------------------------
 * Utils
 * ------------------------------------------------------------ *)

(* Expand typedefs: while the core type is a TyName, replace it by the
 * typedef it names. The use site's specs are merged into the typedef's with
 * insert_specs, and the typedef's declty list is appended after ours. Any
 * other core type is returned unchanged.
 * NOTE(review): no cycle check -- a typedef that names itself (directly or
 * indirectly) would loop; confirm such cycles are rejected elsewhere. *)
let rec resolve_typ env ((specs, coretyp, declty) as typ) =
	match coretyp with
	| TyName id ->
		let (tdspecs, tdcore, tddeclty) = find_typedef env id in
		resolve_typ env (insert_specs specs tdspecs, tdcore, declty @ tddeclty)
	| _ -> typ

(* Apply [f] pairwise over [xs] and [ys]; if their lengths differ, report
 * "wrong number of <what>" as a type error instead. *)
let tc_iter2 env what f xs ys =
	if List.length xs = List.length ys then List.iter2 f xs ys
	else tyerr env (str "wrong number of" <++> str what)
	
