/*
Copyright (C) 2011-2013 Vanderbilt University

Permission is hereby granted, free of charge, to any person obtaining a
copy of this data, including any software or models in source or binary
form, as well as any drawings, specifications, and documentation
(collectively "the Data"), to deal in the Data without restriction,
including without limitation the rights to use, copy, modify, merge,
publish, distribute, sublicense, and/or sell copies of the Data, and to
permit persons to whom the Data is furnished to do so, subject to the
following conditions:

The above copyright notice and this permission notice shall be included
in all copies or substantial portions of the Data.

THE DATA IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
THE AUTHORS, SPONSORS, DEVELOPERS, CONTRIBUTORS, OR COPYRIGHT HOLDERS BE
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
WITH THE DATA OR THE USE OR OTHER DEALINGS IN THE DATA.  
*/
#include "MdaoAssembly.h"

// Output root directory, defined elsewhere. PrintLoadAll() builds its target
// file names as OutputDir + RelPath + "\\<file name>".
extern string OutputDir;

// Constructs an assembly generator. SubDirName is forwarded unchanged to the
// MdaoBase constructor.
MdaoAssembly::MdaoAssembly(string SubDirName) : MdaoBase(SubDirName)
{
	// CreateWorkflow() sets this flag when a scheduling pass makes no
	// progress; PrintWorkflow() emits the workflow only while it is false.
	this->m_isWorkFlowValid = false;

	// Append "Assembly" to the base-class list, then let MdaoBase process it.
	this->BaseClasses.push_back("Assembly");
	MdaoBase::InitBaseClasses();

	// Register the Python imports an assembly needs.
	this->InitImports();
}

// Deletes every component pointer stored in Bases and every wire pointer
// stored in Wires.
MdaoAssembly::~MdaoAssembly(void)
{
	typedef set<MdaoComponent *>::iterator ComponentIter;
	typedef set<MdaoWire *>::iterator WireIter;

	ComponentIter component = Bases.begin();
	while (component != Bases.end())
	{
		delete *component;
		++component;
	}

	WireIter wire = Wires.begin();
	while (wire != Wires.end())
	{
		delete *wire;
		++wire;
	}
}

void MdaoAssembly::PrintLoadAll()
{
	string fileName = OutputDir + RelPath + "\\mdaoLoadAllPrototypes.cmd";
	string fileNameAsm = OutputDir + RelPath + "\\mdaoLoadAsm.cmd";
	ofstream fAsm(fileNameAsm.c_str());
	ofstream fPrototypes(fileName.c_str());
	stringstream f;
	f << ":: ====================================================================";
	f << endl;
	f << ":: Run this file in order to test the functionality of the ";
	f << endl;
	f << ":: generated OpenMDAO python components.";
	f << endl;
	f << ":: Prerequisites: TODO: ...";
	f << endl;
	f << ":: Set the openmadao root folder in your environment vars. TODO: ...";
	f << endl;
	f << ":: ====================================================================";
	f << endl;
	f << "@echo off";
	f << endl;
	f << "set VIRTUAL_ENV=C:\\Users\\rwg7487\\Desktop\\META\\OpenMDAO\\openmdao-0.1.8";
	f << endl;
	f << endl;
	f << "if not defined PROMPT (";
	f << endl;
	f << "    set PROMPT=$P$G";
	f << endl;
	f << ")";
	f << endl;
	f << endl;
	f << "if defined _OLD_VIRTUAL_PROMPT (";
	f << endl;
	f << "    set PROMPT=%_OLD_VIRTUAL_PROMPT%";
	f << endl;
	f << ")";
	f << endl;
	f << endl;
	f << "set _OLD_VIRTUAL_PROMPT=%PROMPT%";
	f << endl;
	f << "set PROMPT=(openmdao-0.1.8) %PROMPT%";
	f << endl;
	f << endl;
	f << "if defined _OLD_VIRTUAL_PATH set PATH=%_OLD_VIRTUAL_PATH%; goto SKIPPATH";
	f << endl;
	f << endl;
	f << "set _OLD_VIRTUAL_PATH=%PATH%";
	f << endl;
	f << endl;
	f << ":SKIPPATH";
	f << endl;
	f << "set PATH=%VIRTUAL_ENV%\\Scripts;%PATH%";
	f << endl;
	f << endl;
	f << ":END";
	f << endl;
	f << endl;

	fPrototypes << f.str();
	fAsm << f.str(); // print the same header
	fAsm << "python " + Name + ".py";
	fAsm << endl;

	for (vector<MdaoComponent *>::const_iterator i = m_workflowComp.begin();
		i != m_workflowComp.end();
		++i)
	{
		fPrototypes << "python " + (*i)->Name + ".py";
		fPrototypes << endl;
	}
	fPrototypes << endl;
	fPrototypes << "pause";
	fPrototypes << endl;
	fPrototypes.close();

	fAsm << endl;
	fAsm << "pause";
	fAsm << endl;
	fAsm.close();

}

// Registers the Python imports used by a generated assembly, after letting
// MdaoBase register its own. Each entry is a (symbol, module) pair.
void MdaoAssembly::InitImports()
{
	MdaoBase::InitImports();

	// { imported symbol, module it comes from }
	static const char * const kAssemblyImports[][2] =
	{
		{ "Assembly",     "openmdao.main.api" },
		{ "CONMINdriver", "openmdao.lib.drivers.api" },
		{ "set_as_top",   "openmdao.main.api" }
	};
	const unsigned int importCount =
		sizeof(kAssemblyImports) / sizeof(kAssemblyImports[0]);

	for (unsigned int idx = 0; idx < importCount; ++idx)
	{
		Imports.insert(make_pair(kAssemblyImports[idx][0], kAssemblyImports[idx][1]));
	}
}

void MdaoAssembly::RefreshWires()
{
	for (set<MdaoWire*>::iterator it = Wires.begin();
		it != Wires.end();
		++it)
	{
		for (set<MdaoComponent*>::iterator itComp = Bases.begin();
			itComp != Bases.end();
			++itComp)
		{
			if ((*itComp)->InstanceName == (*it)->Src)
			{
				(*it)->SrcComp = *itComp;
			}
			else if ((*itComp)->InstanceName == (*it)->Dst)
			{
				(*it)->DstComp = *itComp;
			}
		}
	}
}

// Emits the Python module for this assembly through `py`, in this order:
// component imports, class definition, inputs/outputs, __init__ (component
// instances, workflow, wiring), the driver helper methods and finally a
// "__main__" section that instantiates and runs the assembly.
void MdaoAssembly::PrintPython()
{
	MdaoBase::PrintPython();

	py.CreateComment("OpenMDAO Component Assembly");

	// import components: each component is imported from a module that has
	// the same name as the component class
	for (set<MdaoComponent*>::iterator it = Bases.begin();
		it != Bases.end();
		++it)
	{
		Imports.insert(make_pair((*it)->Name, (*it)->Name));
	}

	// import classes
	PrintImports();
	py.CreateImport("os");
	py.CreateImport("time");

	py.CreateEmptyLine();

	// class definition
	PrintClassDef();

	py.CreateComment("set up interface to the framework", 2);

	// Resolve SrcComp/DstComp on the wires before they are used below.
	RefreshWires();

	// Every input wire contributes an assembly-level input, keyed by the
	// wire's source port name, with the default value "0.0".
	for (set<MdaoWire*>::iterator it = Wires.begin();
		it != Wires.end();
		++it)
	{
		if ((*it)->Type == MdaoWire::Input)
		{
			// Construct the pair directly: make_pair with explicit template
			// arguments cannot take an lvalue under C++11 and later.
			Inputs.insert(pair<string, string>((*it)->SrcPortName, "0.0"));
		}
	}

	PrintInputs();

	PrintOutputs();

	// create init function
	py.CreateFunction("__init__");
	// call super
	py.CreateCodeLine("super(" + Name + ", self).__init__()");

	// Also fills m_CopyBases, which CreateWorkflow() consumes.
	CreateComponentInstances();

	// WorkFlow definition
	CreateWorkflow();
	PrintWorkflow();

	PrintWires();

	py.CloseSection("__init__");

	PrintDriver();

	py.CloseSection("class" + Name);

	// print main function
	py.CreateCodeLine("if __name__ == \"__main__\":", true);
	py.CreateComment("individualComponentTest()");

	py.CreateComment("xx = testAssy()");

	py.CreateCodeLine("xx = " + Name + "()");
	py.CreateCodeLine("xx.instanceDriver()");
	py.CreateCodeLine("xx.formProblem()");
	py.CreateCodeLine("xx.runConMin()");
	py.CreateComment("del(xx)");
	py.CreateComment("os._exit(1)");

	py.CloseSection(Name + ".py");
}

void MdaoAssembly::PrintWires()
{
	py.CreateEmptyLine();
	// print connections
	py.CreateComment("----------Wiring the components together-----------", 2);

	// print input connections
	py.CreateComment("Input Connections");
	for (set<MdaoWire*>::iterator it = Wires.begin();
		it != Wires.end();
		++it)
	{
		if ((*it)->Type == MdaoWire::Input)
		{
			py.CreateCodeLine("self.connect('" \
				+ (*it)->SrcPortName \
				+ "', '"             \
				+ (*it)->Dst         \
				+ "."                \
				+ (*it)->DstPortName \
				+ "')");
		}
	}
	for (set<MdaoWire*>::iterator it = Wires.begin();
		it != Wires.end();
		++it)
	{
		if ((*it)->Type == MdaoWire::Input)
		{
			py.CreateComment("self.passthrough('" \
				+ (*it)->Dst         \
				+ "."                \
				+ (*it)->DstPortName \
				+ "')");
		}
	}
	// print internal connections
	py.CreateComment("Internal Connections");
	for (set<MdaoWire*>::iterator it = Wires.begin();
		it != Wires.end();
		++it)
	{
		if ((*it)->Type == MdaoWire::Internal)
		{
			py.CreateCodeLine("self.connect('" \
				+ (*it)->Src         \
				+ "."                \
				+ (*it)->SrcPortName \
				+ "', '"             \
				+ (*it)->Dst         \
				+ "."                \
				+ (*it)->DstPortName \
				+ "')");
		}
	}
	// print output connections
	py.CreateComment("Output Connections");
	for (set<MdaoWire*>::iterator it = Wires.begin();
		it != Wires.end();
		++it)
	{
		if ((*it)->Type == MdaoWire::Output)
		{
			py.CreateCodeLine("self.create_passthrough('" \
				+ (*it)->Src         \
				+ "."                \
				+ (*it)->SrcPortName \
				+ "')");
		}
	}
}

void MdaoAssembly::PrintComponents()
{
	for (set<MdaoComponent *>::iterator it = Bases.begin();
		it != Bases.end();
		++it)
	{
		(*it)->PrintPython();
	}
}

void MdaoAssembly::CreateComponentInstances()
{
	py.CreateEmptyLine();
	// create component instances
	py.CreateComment("Create component instances", 2);
	for (set<MdaoComponent*>::iterator it = Bases.begin();
		it != Bases.end();
		++it)
	{
		m_CopyBases.insert(*it);
		string instance = "self.add('" \
			+ (*it)->InstanceName \
			+ "', "               \
			+ (*it)->Name         \
			+ "(";
		if ((*it)->Type == MdaoComponent::ComponentType::MatLab)
		{
			instance += "r\"";
			instance += (*it)->TargetFileName.substr(0, (*it)->TargetFileName.length() - 1);
			instance += "\"";
		}
		instance += "))";

		py.CreateCodeLine(instance);
	}
}

// Emits three Python methods on the generated assembly class:
//   instanceDriver - adds a CONMINdriver and sets its static options
//   formProblem    - registers objectives, design variables (with low/high
//                    bounds) and constraints taken from Objectives,
//                    DesignVariables and Constraints
//   runConMin      - runs the assembly and prints the iteration count, the
//                    design variable values and the elapsed time
// DesignVariables and Constraints map a name to a (low, high) pair of strings.
void MdaoAssembly::PrintDriver()
{
	typedef map<string, pair<string, string> >::iterator BoundsIter;

	py.CreateEmptyLine();
	// create optimizer
	py.CreateFunction("instanceDriver");
	py.CreateComment("$Static");
	py.CreateCodeLine("self.add('driver', CONMINdriver())");
	py.CreateCodeLine("self.driver.iprint = 0");
	py.CreateCodeLine("self.driver.itmax = 300");
	py.CreateCodeLine("self.driver.fdch = .000001");
	py.CreateCodeLine("self.driver.fdchm = .000001");
	py.CloseSection("instanceDriver");

	// TODO This region
	py.CreateFunction("formProblem");
	py.CreateComment("$GENERIC based on optimizer objective component definition");

	for (set<string>::iterator it = Objectives.begin();
		it != Objectives.end();
		++it)
	{
		py.CreateCodeLine("self.driver.add_objective('" + *it + "')");
	}

	py.CreateComment("#$GENERIC based on optimizer design variable component definition(s)");
	for (BoundsIter it = DesignVariables.begin();
		it != DesignVariables.end();
		++it)
	{
		py.CreateCodeLine("self.driver.add_parameter('" +
			it->first +
			"', low=" + it->second.first +
			", high=" + it->second.second + ")");
	}

	// Initialise every design variable attribute to its lower bound.
	for (BoundsIter it = DesignVariables.begin();
		it != DesignVariables.end();
		++it)
	{
		py.CreateCodeLine("self." + it->first + " = " + it->second.first);
	}

	py.CreateComment("#$GENERIC based on TBD optimizer constraint variable component definition(s)");
	for (BoundsIter it = Constraints.begin();
		it != Constraints.end();
		++it)
	{
		// A two-sided bound is emitted as two separate driver constraints.
		py.CreateComment(it->second.first + " <= " + it->first + " <= " + it->second.second);
		py.CreateCodeLine("self.driver.add_constraint('" + it->first + " >= " + it->second.first + "')");
		py.CreateCodeLine("self.driver.add_constraint('" + it->first + " <= " + it->second.second + "')");
	}
	py.CloseSection("formProblem");

	py.CreateFunction("runConMin");
	py.CreateComment("$Static");
	py.CreateCodeLine("self.driver.iprint = 1");
	py.CreateCodeLine("self.driver.itmax = 300");
	py.CreateCodeLine("set_as_top(self)");

	py.CreateCodeLine("tt = time.time()");
	py.CreateCodeLine("self.run()");
	py.CreateCodeLine("print \"\\n\"");
	py.CreateCodeLine("print \"CONMIN Iterations: \", self.driver.iter_count");
	py.CreateComment("$GENERIC");
	// TODO This region
	py.CreateComment("print \"Minimum found at (%f, %f)\" % (self.bp_a, \\");
	py.CreateComment("self.bp_c)");

	// Build: print "Minimum found at (%f, %f, ...)" % (\
	// with one "%f" per design variable, comma separated.
	const size_t designVarCount = DesignVariables.size();
	string FirstLine = "print \"Minimum found at (";
	for (size_t i = 0; i < designVarCount; ++i)
	{
		if (i != 0)
		{
			FirstLine += ", ";
		}
		FirstLine += "%f";
	}
	if (designVarCount == 0)
	{
		// No argument lines follow, so the argument tuple has to be closed
		// here; otherwise the generated statement would end in an open "(\".
		FirstLine += ")\" % ()";
	}
	else
	{
		FirstLine += ")\" % (\\";
	}

	py.CreateCodeLine(FirstLine);

	// One continuation line per design variable; the last one closes the tuple.
	for (BoundsIter it = DesignVariables.begin();
		it != DesignVariables.end();
		++it)
	{
		string Code = "self." + it->first;

		BoundsIter next = it;
		++next;
		if (next == DesignVariables.end())
		{
			// last element
			Code += ")";
		}
		else
		{
			// not the last element
			Code += ", \\";
		}

		py.CreateCodeLine(Code);
	}

	py.CreateComment("$Static");
	py.CreateCodeLine("print \"Elapsed time: \", time.time()-tt, \"seconds\"");
	py.CloseSection("runConMin");
}

// Emits the "self.driver.workflow.add([...])" statement listing the instance
// names in m_workflowComp, one per continuation line, or an error comment
// when the workflow could not be determined.
// Note: despite its name, m_isWorkFlowValid is set to true by CreateWorkflow()
// when a scheduling pass makes NO progress, so "true" is the failure case here.
void MdaoAssembly::PrintWorkflow()
{
	py.CreateComment("----------Defining the work flow-----------", 2);

	if (m_isWorkFlowValid)
	{
		py.CreateComment("ERROR: Could not determine the workflow based on the connections!");
		return;
	}

	// print workflow
	py.CreateComment("TODO: Double check workflow order");
	py.CreateCodeLine("self.driver.workflow.add([ \\", true);
	for (vector<MdaoComponent*>::iterator it = m_workflowComp.begin();
		it != m_workflowComp.end();
		++it)
	{
		string Code = "'" + (*it)->InstanceName + "'";
		// Every entry except the last is followed by a comma.
		if (it + 1 != m_workflowComp.end())
		{
			Code += ",";
		}
		Code += " \\";
		py.CreateCodeLine(Code);
	}
	py.CreateCodeLine("])");
	py.CloseSection();
}

void MdaoAssembly::CreateWorkflow()
{
	// detect workflow
	set<MdaoWire*> PossibleInputs;
	set<MdaoWire*> ProcessWires;
	for (set<MdaoWire*>::iterator it = Wires.begin();
		it != Wires.end();
		++it)
	{
		if ((*it)->Type == MdaoWire::Input)
		{
			// first pass
			PossibleInputs.insert(*it);
		}
	}

	while (!m_CopyBases.empty() && !m_isWorkFlowValid)
	{
		unsigned int currentSize = m_CopyBases.size();
		for (set<MdaoComponent*>::iterator itBases = Bases.begin();
			itBases != Bases.end();
			++itBases)
		{
			int countPossible = 0;
			int countAll = 0;
			for (set<MdaoWire*>::iterator it = PossibleInputs.begin();
				it != PossibleInputs.end();
				++it)
			{
				if ((*it)->DstComp == *itBases)
				{
					countPossible++;
				}
			}
			for (set<MdaoWire*>::iterator it = Wires.begin();
				it != Wires.end();
				++it)
			{
				if ((*it)->DstComp == *itBases)
				{
					countAll++;
				}
			}
			if (countPossible == countAll)
			{
				if (m_CopyBases.erase(*itBases) == 1)
				{
					m_workflowComp.push_back(*itBases);
				}
				for (set<MdaoWire*>::iterator it = Wires.begin();
					it != Wires.end();
					++it)
				{
					if ((*it)->SrcComp == *itBases)
					{
						PossibleInputs.insert(*it);
					}
				}
			}
		}
		if (currentSize <= m_CopyBases.size())
		{
			// number of the items in CopyBases should decrease
			m_isWorkFlowValid = true;
		}
	}
}
