/*
 * Copyright © 2003-2024 Dynare Team
 *
 * This file is part of Dynare.
 *
 * Dynare is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Dynare is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with Dynare.  If not, see <https://www.gnu.org/licenses/>.
 */

#ifndef STATIC_MODEL_HH
#define STATIC_MODEL_HH

#include <filesystem>
#include <fstream>

#include "Bytecode.hh"
#include "ModelTree.hh"

using namespace std;

class DynamicModel;

//! Stores a static model, as derived from the "model" block when leads and lags have been removed
class StaticModel : public ModelTree
{
private:
  /* First-order derivatives of equations w.r.t. Lagrange multipliers, using
     chain rule derivation for auxiliary variables added after the multipliers
     (so that derivatives of optimality FOCs w.r.t. multipliers with lead or
     lag ⩾ 2 are self-contained, which is required by dyn_ramsey_static.m).
     Only used if 'ramsey_model' or 'ramsey_policy' is present.
     The first index of the key is the equation number (NB: auxiliary equations
     added after the multipliers do not appear).
     The second index is the index of the Lagrange multiplier (ordered by
     increasing symbol ID) */
  SparseColumnMajorOrderMatrix ramsey_multipliers_derivatives;
  /* Column indices for the derivatives w.r.t. Lagrange multipliers in
     Compressed Sparse Column (CSC) storage (corresponds to the “jc” vector in
     MATLAB terminology) */
  vector<int> ramsey_multipliers_derivatives_sparse_colptr;
  // Temporary terms for ramsey_multipliers_derivatives
  temporary_terms_t ramsey_multipliers_derivatives_temporary_terms;
  // Stores, for each temporary term, its index in the MATLAB/Octave vector
  temporary_terms_idxs_t ramsey_multipliers_derivatives_temporary_terms_idxs;

  /* Value of the “static_mfs” option of “model” block (or the “model_options”
     command).
     NB: the default value defined here is not used when converting from
     DynamicModel class, and in particular it does not affect the main “model”
     block. See the DynamicModel class for the default value in that case. */
  int static_mfs {0};

  //! Writes the code of the block-decomposed model in virtual machine bytecode
  void writeStaticBlockBytecode(const string& basename) const;

  //! Writes the code of the model in virtual machine bytecode
  void writeStaticBytecode(const string& basename) const;

  //! Computes the Jacobian and prepares for equation normalization
  /*! Using values from initval/endval blocks and parameter initializations:
    - computes the Jacobian of the model w.r.t. contemporaneous variables
    - removes edges of the incidence matrix when the derivative w.r.t. the corresponding
      variable is too close to zero (below the cutoff)
  */
  void evaluateJacobian(const eval_context_t& eval_context, jacob_map_t* j_m, bool dynamic);

  SymbolType getTypeByDerivID(int deriv_id) const noexcept(false) override;
  int getLagByDerivID(int deriv_id) const noexcept(false) override;
  int getSymbIDByDerivID(int deriv_id) const noexcept(false) override;
  int getTypeSpecificIDByDerivID(int deriv_id) const override;

  // Jacobian column of a derivation ID: its type-specific ID, for both the
  // sparse and the non-sparse representations
  int
  getJacobianCol(int deriv_id, [[maybe_unused]] bool sparse) const override
  {
    return getTypeSpecificIDByDerivID(deriv_id);
  }
  // The static Jacobian has one column per endogenous variable, whatever the
  // representation
  int
  getJacobianColsNbr([[maybe_unused]] bool sparse) const override
  {
    return symbol_table.endo_nbr();
  }

  void computeChainRuleJacobian() override;

  /* Column of variable “var” in the Jacobian of block “blk”: variables whose
     index is below blocks[blk].getRecursiveSize() (presumably the recursively
     determined ones) are not part of that Jacobian, hence the shift. The lag is
     not used. */
  int
  getBlockJacobianEndoCol(int blk, int var, [[maybe_unused]] int lag) const override
  {
    assert(var >= blocks[blk].getRecursiveSize());
    return var - blocks[blk].getRecursiveSize();
  }

  // Write the block structure of the model in the driver file
  void writeBlockDriverOutput(ostream& output) const;

  // Helper for writing ramsey_multipliers_derivatives
  template<ExprNodeOutputType output_type>
  void writeRamseyMultipliersDerivativesHelper(ostream& output) const;

  //! Internal helper for the copy constructor and assignment operator
  /*! Copies all the structures that contain ExprNode*, by converting the
      pointers into their equivalent in the new tree */
  void copyHelper(const StaticModel& m);

protected:
  string
  modelClassName() const override
  {
    return "static model";
  }

public:
  StaticModel(SymbolTable& symbol_table_arg, NumericalConstants& num_constants,
              ExternalFunctionsTable& external_functions_table_arg,
              HeterogeneityTable& heterogeneity_table_arg);

  // Copy operations (implemented with the help of copyHelper)
  StaticModel(const StaticModel& m);
  StaticModel& operator=(const StaticModel& m);

  //! Creates the static version of a dynamic model
  explicit StaticModel(const DynamicModel& m);

  //! Writes information about the static model to the driver file
  void writeDriverOutput(ostream& output) const;

  //! Executes computations (variable sorting + derivation + block decomposition)
  /*!
    \param derivsOrder order of derivation with respect to endogenous variables
    \param paramsDerivsOrder order of derivatives w.r.t. a pair (endogenous, parameter) to be
    computed
    \param eval_context evaluation context for normalization
    \param no_tmp_terms if true, no temporary terms will be computed in the static files
    \param block presumably requests the block decomposition of the model (TODO: confirm)
    \param use_dll presumably selects compiled (DLL/MEX) output instead of MATLAB/Octave
    files (TODO: confirm)
  */
  void computingPass(int derivsOrder, int paramsDerivsOrder, const eval_context_t& eval_context,
                     bool no_tmp_terms, bool block, bool use_dll);

  //! Writes static model file (+ bytecode)
  void writeStaticFile(const string& basename, bool use_dll, const string& mexext,
                       const filesystem::path& matlabroot, bool julia) const;

  //! Write JSON Output (used by PlannerObjectiveStatement)
  void writeJsonOutput(ostream& output) const;

  //! Write JSON representation of static model
  void writeJsonComputingPassOutput(ostream& output, bool writeDetails) const;

  //! Write JSON params derivatives
  void writeJsonParamsDerivatives(ostream& output, bool writeDetails) const;

  //! Writes file containing static parameters derivatives
  template<bool julia>
  void writeParamsDerivativesFile(const string& basename) const;

  //! Writes LaTeX file with the equations of the static model
  void writeLatexFile(const string& basename, bool write_equation_tags) const;

  //! Writes initializations in oo_.steady_state or steady state file for the auxiliary variables
  void writeAuxVarInitval(ostream& output, ExprNodeOutputType output_type) const;

  void writeLatexAuxVarRecursiveDefinitions(ostream& output) const;
  void writeJsonAuxVarRecursiveDefinitions(ostream& output) const;

  //! To ensure that no exogenous is present in the planner objective
  //! See #1264
  bool exoPresentInEqs() const;

  int getDerivID(int symb_id, int lag) const noexcept(false) override;
  void addAllParamDerivId(set<int>& deriv_id_set) override;

  // Fills the ramsey_multipliers_derivatives structure (see the comment there)
  void computeRamseyMultipliersDerivatives(int ramsey_orig_endo_nbr, bool is_matlab,
                                           bool no_tmp_terms);

  // Writes the sparse indices of ramsey_multipliers_derivatives to the driver file
  void writeDriverRamseyMultipliersDerivativesSparseIndices(ostream& output) const;

  // Writes ramsey_multipliers_derivatives (MATLAB/Octave version)
  void writeRamseyMultipliersDerivativesMFile(const string& basename,
                                              int ramsey_orig_endo_nbr) const;

  // Writes ramsey_multipliers_derivatives (C version)
  void writeRamseyMultipliersDerivativesCFile(const string& basename, const string& mexext,
                                              const filesystem::path& matlabroot,
                                              int ramsey_orig_endo_nbr) const;

  // Returns the value of the “static_mfs” option
  int
  getMFS() const override
  {
    return static_mfs;
  }
};

template<bool julia>
void
StaticModel::writeParamsDerivativesFile(const string& basename) const
{
  if (!params_derivatives.size())
    return;

  constexpr ExprNodeOutputType output_type {julia ? ExprNodeOutputType::juliaSparseStaticModel
                                                  : ExprNodeOutputType::matlabSparseStaticModel};

  auto [tt_output, rp_output, g1p_output, rpp_output, g1pp_output, g2p_output,
        g3p_output] {writeParamsDerivativesFileHelper<output_type>()};
  // g3p_output is ignored

  if constexpr (!julia)
    {
      filesystem::path filename {packageDir(basename) / "static_params_derivs.m"};
      ofstream paramsDerivsFile {filename, ios::out | ios::binary};
      if (!paramsDerivsFile.is_open())
        {
          cerr << "ERROR: Can't open file " << filename.string() << " for writing" << endl;
          exit(EXIT_FAILURE);
        }
      paramsDerivsFile
          << "function [rp, g1p, rpp, g1pp, g2p] = static_params_derivs(y, x, params)" << endl
          << "%" << endl
          << "% Status : Computes derivatives of the static model with respect to the parameters"
          << endl
          << "%" << endl
          << "% Inputs : " << endl
          << "%   y         [M_.endo_nbr by 1] double    vector of endogenous variables in "
             "declaration order"
          << endl
          << "%   x         [M_.exo_nbr by 1] double     vector of exogenous variables in "
             "declaration order"
          << endl
          << "%   params    [M_.param_nbr by 1] double   vector of parameter values in declaration "
             "order"
          << endl
          << "%" << endl
          << "% Outputs:" << endl
          << "%   rp        [M_.eq_nbr by #params] double    Jacobian matrix of static model "
             "equations with respect to parameters "
          << endl
          << "%                                              Dynare may prepend or append "
             "auxiliary equations, see M_.aux_vars"
          << endl
          << "%   g1p       [#first_order_Jacobian_terms by 4] double    Derivative of the "
             "Jacobian matrix of the static model equations with respect to the parameters"
          << endl
          << "%                                                              rows: respective "
             "derivative term"
          << endl
          << "%                                                              1st column: equation "
             "number of the term appearing"
          << endl
          << "%                                                              2nd column: number of "
             "the variable in derivative"
          << endl
          << "%                                                              3rd column: number of "
             "the parameter in derivative"
          << endl
          << "%                                                              4th column: value of "
             "the derivative term"
          << endl
          << "%   rpp       [#second_order_residual_terms by 4] double   Hessian matrix of second "
             "derivatives of residuals with respect to parameters;"
          << endl
          << "%                                                              rows: respective "
             "derivative term"
          << endl
          << "%                                                              1st column: equation "
             "number of the term appearing"
          << endl
          << "%                                                              2nd column: number of "
             "the first parameter in derivative"
          << endl
          << "%                                                              3rd column: number of "
             "the second parameter in derivative"
          << endl
          << "%                                                              4th column: value of "
             "the Hessian term"
          << endl
          << "%   g1pp     [#second_order_Jacobian_terms by 5] double   Hessian matrix of second "
             "derivatives of the Jacobian with respect to the parameters;"
          << endl
          << "%                                                              rows: respective "
             "derivative term"
          << endl
          << "%                                                              1st column: equation "
             "number of the term appearing"
          << endl
          << "%                                                              2nd column: column "
             "number of variable in Jacobian of the static model"
          << endl
          << "%                                                              3rd column: number of "
             "the first parameter in derivative"
          << endl
          << "%                                                              4th column: number of "
             "the second parameter in derivative"
          << endl
          << "%                                                              5th column: value of "
             "the Hessian term"
          << endl
          << "%   g2p     [#first_order_Hessian_terms by 5] double   Jacobian matrix of "
             "derivatives of the static Hessian with respect to the parameters;"
          << endl
          << "%                                                              rows: respective "
             "derivative term"
          << endl
          << "%                                                              1st column: equation "
             "number of the term appearing"
          << endl
          << "%                                                              2nd column: column "
             "number of first variable in Hessian of the static model"
          << endl
          << "%                                                              3rd column: column "
             "number of second variable in Hessian of the static model"
          << endl
          << "%                                                              4th column: number of "
             "the parameter in derivative"
          << endl
          << "%                                                              5th column: value of "
             "the Hessian term"
          << endl
          << "%" << endl
          << "%" << endl
          << "% Warning : this file is generated automatically by Dynare" << endl
          << "%           from model file (.mod)" << endl
          << endl
          << "T = NaN(" << params_derivs_temporary_terms_idxs.size() << ",1);" << endl
          << tt_output.str() << "rp_i = NaN(" << params_derivatives.at({0, 1}).size() << ", 1);"
          << endl
          << "rp_j = NaN(" << params_derivatives.at({0, 1}).size() << ", 1);" << endl
          << "rp_v = NaN(" << params_derivatives.at({0, 1}).size() << ", 1);" << endl
          << rp_output.str() << "rp = sparse(rp_i, rp_j, rp_v, " << equations.size() << ", "
          << symbol_table.param_nbr() << ");" << endl
          << "g1p = NaN(" << params_derivatives.at({1, 1}).size() << ",4);" << endl
          << g1p_output.str() << "if nargout >= 3" << endl
          << "rpp = NaN(" << params_derivatives.at({0, 2}).size() << ",4);" << endl
          << rpp_output.str() << "g1pp = NaN(" << params_derivatives.at({1, 2}).size() << ",5);"
          << endl
          << g1pp_output.str() << "end" << endl
          << "if nargout >= 5" << endl
          << "g2p = NaN(" << params_derivatives.at({2, 1}).size() << ",5);" << endl
          << g2p_output.str() << "end" << endl
          << "end" << endl;
      paramsDerivsFile.close();
    }
  else
    {
      stringstream output;
      output << "# NB: this file was automatically generated by Dynare" << endl
             << "#     from " << basename << ".mod" << endl
             << "#" << endl
             << "function static_params_derivs(y, x, params)" << endl
             << "@inbounds begin" << endl
             << tt_output.str() << "rp_i = fill(NaN, " << params_derivatives.at({0, 1}).size()
             << ");" << endl
             << "rp_j = fill(NaN, " << params_derivatives.at({0, 1}).size() << ");" << endl
             << "rp_v = fill(NaN, " << params_derivatives.at({0, 1}).size() << ");" << endl
             << rp_output.str() << "rp = sparse(rp_i, rp_j, rp_v, " << equations.size() << ", "
             << symbol_table.param_nbr() << ");" << endl
             << "g1p = fill(NaN, " << params_derivatives.at({1, 1}).size() << ",4);" << endl
             << g1p_output.str() << "rpp = fill(NaN, " << params_derivatives.at({0, 2}).size()
             << ",4);" << endl
             << rpp_output.str() << "g1pp = fill(NaN, " << params_derivatives.at({1, 2}).size()
             << ",5);" << endl
             << g1pp_output.str() << "g2p = fill(NaN, " << params_derivatives.at({2, 1}).size()
             << ",5);" << endl
             << g2p_output.str() << "end" << endl
             << "return (rp, g1p, rpp, g1pp, g2p)" << endl
             << "end" << endl;

      writeToFileIfModified(output, filesystem::path {basename} / "model" / "julia"
                                        / "StaticParamsDerivs.jl");
    }
}

template<ExprNodeOutputType output_type>
void
StaticModel::writeRamseyMultipliersDerivativesHelper(ostream& output) const
{
  /* Emits, in the syntax selected by output_type:
     1. the temporary terms of the derivatives w.r.t. Lagrange multipliers
        (including external function calls);
     2. one assignment “g1m_v[i]=<expression>;” per element of
        ramsey_multipliers_derivatives, in the container's iteration order. */
  deriv_node_temp_terms_t external_fn_terms;
  temporary_terms_t written_tt; // required by writeTemporaryTerms, not used afterwards
  writeTemporaryTerms<output_type>(ramsey_multipliers_derivatives_temporary_terms, written_tt,
                                   ramsey_multipliers_derivatives_temporary_terms_idxs, output,
                                   external_fn_terms);

  // The subscript of the first element is the target language's array base offset
  int subscript = ARRAY_SUBSCRIPT_OFFSET(output_type);
  for (const auto& [row_col, deriv] : ramsey_multipliers_derivatives)
    {
      output << "g1m_v" << LEFT_ARRAY_SUBSCRIPT(output_type) << subscript
             << RIGHT_ARRAY_SUBSCRIPT(output_type) << "=";
      deriv->writeOutput(output, output_type, ramsey_multipliers_derivatives_temporary_terms,
                         ramsey_multipliers_derivatives_temporary_terms_idxs, external_fn_terms);
      output << ";" << endl;
      ++subscript;
    }
}

#endif
