/*
 * Copyright © 2019-2025 Dynare Team
 *
 * This file is part of Dynare.
 *
 * Dynare is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * Dynare is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with Dynare.  If not, see <https://www.gnu.org/licenses/>.
 */

#include "DynamicModelCaller.hh"

#include <algorithm>
#include <array>
#include <filesystem>

using namespace std::literals::string_literals;

/* Definitions of the static error state of DynamicModelCaller.
   error_msg and error_id are written by setErrMsg() and setMException(),
   which both hold error_mtx while updating them. */
std::string DynamicModelCaller::error_msg;
std::string DynamicModelCaller::error_id;
std::mutex DynamicModelCaller::error_mtx;

/* Records a plain error message in the shared error state.

   msg is taken by value and moved into error_msg. The error identifier is
   cleared, so that an identifier left by a previous setMException() call
   cannot be associated with this new message. Both fields are updated while
   holding error_mtx. */
void
DynamicModelCaller::setErrMsg(std::string msg)
{
  std::lock_guard lk {error_mtx};
  // Explicitly qualified: do not rely on ADL to pick up std::move
  error_msg = std::move(msg);
  error_id.clear();
}

/* Records, in the shared error state, the message and identifier carried by
   an exception object.

   exception must be either an object of class MException (whose “message”
   and “identifier” properties are read) or a struct with “message” and
   “identifier” fields (the Octave case). Anything else, or an object whose
   two values are missing, not character arrays, or not convertible to C
   strings, aborts the MEX call through mexErrMsgTxt(). */
void
DynamicModelCaller::setMException(const mxArray* exception)
{
  const mxArray* message_mx {nullptr};
  const mxArray* identifier_mx {nullptr};
  if (mxIsClass(exception, "MException"))
    {
      /* NOTE(review): mxGetProperty is documented as returning a copy of the
         property value; these copies are not explicitly destroyed here —
         confirm whether relying on automatic MEX cleanup is intended. */
      message_mx = mxGetProperty(exception, 0, "message");
      identifier_mx = mxGetProperty(exception, 0, "identifier");
    }
  else if (mxIsStruct(exception)) // For Octave
    {
      message_mx = mxGetField(exception, 0, "message");
      identifier_mx = mxGetField(exception, 0, "identifier");
    }
  else
    mexErrMsgTxt("Exception of incorrect type");

  if (!message_mx || !mxIsChar(message_mx) || !identifier_mx || !mxIsChar(identifier_mx))
    mexErrMsgTxt("Exception object malformed");

  char* message = mxArrayToString(message_mx);
  char* identifier = mxArrayToString(identifier_mx);
  /* mxArrayToString() may fail and return a null pointer; constructing a
     std::string from a null pointer is undefined behaviour, so bail out */
  if (!message || !identifier)
    mexErrMsgTxt("Exception object could not be converted to strings");

  /* Build the C++ strings and release the C buffers before taking the lock,
     so that the critical section only consists of two moves */
  std::string new_msg {message};
  std::string new_id {identifier};
  mxFree(message);
  mxFree(identifier);

  std::lock_guard lk {error_mtx};
  error_msg = std::move(new_msg);
  error_id = std::move(new_id);
}

/* Constructs the common part of the DLL-based callers.

   The pointers params_arg, steady_state_arg and g1_sparse_colptr_arg are
   stored as-is (the pointed-to data is not copied by this constructor).
   linear_arg and compute_jacobian_arg are forwarded to the
   DynamicModelCaller base class.

   The work buffers are sized here: tt gets ntt elements, y_p gets 3*ny
   elements and x_p gets nx elements.
   NOTE(review): the factor 3 presumably corresponds to lagged, current and
   lead values of the endogenous variables — confirm against the header. */
DynamicModelDllCaller::DynamicModelDllCaller(size_t ntt, mwIndex ny, mwIndex nx,
                                             const double* params_arg,
                                             const double* steady_state_arg,
                                             const int32_T* g1_sparse_colptr_arg, bool linear_arg,
                                             bool compute_jacobian_arg) :
    DynamicModelCaller {linear_arg, compute_jacobian_arg},
    params {params_arg},
    steady_state {steady_state_arg},
    g1_sparse_colptr {g1_sparse_colptr_arg}
{
  tt.resize(ntt);
  y_p.resize(3 * ny);
  x_p.resize(nx);
}

/* Copies the stored values of column col of the Jacobian into dest.

   The values of a column are contiguous in jacobian_p; their position is
   given by g1_sparse_colptr, whose entries are used here as one-based offsets
   (hence the “- 1” when computing the start address). dest must have room for
   g1_sparse_colptr[col+1] - g1_sparse_colptr[col] values. */
void
DynamicModelDllCaller::copy_jacobian_column(mwIndex col, double* dest) const
{
  const auto col_begin {g1_sparse_colptr[col]};
  const auto col_end {g1_sparse_colptr[col + 1]};
  std::ranges::copy_n(jacobian_p.data() + col_begin - 1, col_end - col_begin, dest);
}

/* Constructs the common part of the callers that go through M-files.

   - basename_arg is moved into basename;
   - y_mx and x_mx are freshly created real double column vectors, of
     respective sizes 3*ny and nx;
   - jacobian_mx starts as a null pointer (no Jacobian stored yet);
   - the parameter, steady state and g1_sparse_{rowval,colval,colptr} arrays
     are duplicated with mxDuplicateArray(), so this object holds its own
     copies of them.

   All the arrays created or duplicated here are released by the destructor.
   linear_arg and compute_jacobian_arg are forwarded to the
   DynamicModelCaller base class. */
DynamicModelMatlabCaller::DynamicModelMatlabCaller(std::string basename_arg, mwIndex ny, mwIndex nx,
                                                   const mxArray* params_mx_arg,
                                                   const mxArray* steady_state_mx_arg,
                                                   const mxArray* g1_sparse_rowval_mx_arg,
                                                   const mxArray* g1_sparse_colval_mx_arg,
                                                   const mxArray* g1_sparse_colptr_mx_arg,
                                                   bool linear_arg, bool compute_jacobian_arg) :
    DynamicModelCaller {linear_arg, compute_jacobian_arg},
    basename {std::move(basename_arg)},
    y_mx {mxCreateDoubleMatrix(3 * ny, 1, mxREAL)},
    x_mx {mxCreateDoubleMatrix(nx, 1, mxREAL)},
    jacobian_mx {nullptr},
    params_mx {mxDuplicateArray(params_mx_arg)},
    steady_state_mx {mxDuplicateArray(steady_state_mx_arg)},
    g1_sparse_rowval_mx {mxDuplicateArray(g1_sparse_rowval_mx_arg)},
    g1_sparse_colval_mx {mxDuplicateArray(g1_sparse_colval_mx_arg)},
    g1_sparse_colptr_mx {mxDuplicateArray(g1_sparse_colptr_mx_arg)}
{
}

/* Releases every mxArray held by this object: the y and x vectors, the
   stored Jacobian (only if there is one), and the duplicated parameter,
   steady state and sparsity pattern arrays. */
DynamicModelMatlabCaller::~DynamicModelMatlabCaller()
{
  for (mxArray* input_mx : {y_mx, x_mx})
    mxDestroyArray(input_mx);

  // The Jacobian only exists once an evaluation has stored one
  if (jacobian_mx != nullptr)
    mxDestroyArray(jacobian_mx);

  for (mxArray* duplicate_mx : {params_mx, steady_state_mx, g1_sparse_rowval_mx,
                                g1_sparse_colval_mx, g1_sparse_colptr_mx})
    mxDestroyArray(duplicate_mx);
}

/* Copies column col of the stored Jacobian into dest, laid out according to
   the reference sparsity pattern g1_sparse_{rowval,colptr}.

   dest receives one value per entry of column col in the reference pattern;
   entries that are absent from the stored sparse matrix are written as 0.
   If no Jacobian is currently stored, dest is left untouched. */
void
DynamicModelMatlabCaller::copy_jacobian_column(mwIndex col, double* dest) const
{
  if (!jacobian_mx)
    return;

  // Reference sparsity pattern (indices are used as one-based below)
  const int32_T* const ref_rowval {mxGetInt32s(g1_sparse_rowval_mx)};
  const int32_T* const ref_colptr {mxGetInt32s(g1_sparse_colptr_mx)};

  /* The sparsity pattern of jacobian_mx may not coincide with the reference
     one: the call to sparse() in dynamic_g1.m may have dropped elements that
     are numerically zero although symbolically non-zero. Hence the merge of
     the two sorted row index lists performed below. */
  const mwIndex* const src_ir {mxGetIr(jacobian_mx)};
  const mwIndex* const src_jc {mxGetJc(jacobian_mx)};

  const mwIndex src_end {src_jc[col + 1]};
  const auto nnz_dest {static_cast<mwIndex>(ref_colptr[col + 1] - ref_colptr[col])};

  mwIndex isrc {src_jc[col]}; // Position in the value array of the source Jacobian
  for (mwIndex idest {0}; idest < nnz_dest; ++idest)
    {
      // Zero-based row of the idest-th reference entry of this column
      const auto row {static_cast<mwIndex>(ref_rowval[idest + ref_colptr[col] - 1] - 1)};

      // Skip source entries located above the wanted row
      while (isrc < src_end && src_ir[isrc] < row)
        ++isrc;

      const bool present {isrc < src_end && src_ir[isrc] == row};
      dest[idest] = present ? mxGetDoubles(jacobian_mx)[isrc] : 0.0;
    }
}

/* Definitions of the static members of DynamicModelNoblockDllCaller.

   resid_mex and g1_mex are the handles of the two dynamically loaded MEX
   files (a void* as returned by dlopen() on POSIX systems, a HINSTANCE as
   returned by LoadLibraryW() under Windows/Cygwin). The four function
   pointers are looked up in those libraries by load_dll(). Everything starts
   as a null pointer. */
#if !defined(_WIN32) && !defined(__CYGWIN32__)
void* DynamicModelNoblockDllCaller::resid_mex {nullptr};
void* DynamicModelNoblockDllCaller::g1_mex {nullptr};
#else
HINSTANCE DynamicModelNoblockDllCaller::resid_mex {nullptr};
HINSTANCE DynamicModelNoblockDllCaller::g1_mex {nullptr};
#endif
DynamicModelNoblockDllCaller::dynamic_tt_fct DynamicModelNoblockDllCaller::residual_tt_fct {
    nullptr},
    DynamicModelNoblockDllCaller::g1_tt_fct {nullptr};
DynamicModelNoblockDllCaller::dynamic_fct DynamicModelNoblockDllCaller::residual_fct {nullptr},
    DynamicModelNoblockDllCaller::g1_fct {nullptr};

/* Constructs a caller for the DLL version of a model without block
   decomposition.

   All arguments are forwarded to DynamicModelDllCaller. In addition, when the
   Jacobian has to be computed, the buffer receiving its values is sized from
   the last entry of the column pointer array (the Jacobian having 3*ny+nx
   columns, and the column pointers being one-based). */
DynamicModelNoblockDllCaller::DynamicModelNoblockDllCaller(
    size_t ntt, mwIndex ny, mwIndex nx, const double* params_arg, const double* steady_state_arg,
    const int32_T* g1_sparse_colptr_arg, bool linear_arg, bool compute_jacobian_arg) :
    DynamicModelDllCaller {ntt,
                           ny,
                           nx,
                           params_arg,
                           steady_state_arg,
                           g1_sparse_colptr_arg,
                           linear_arg,
                           compute_jacobian_arg}
{
  // No Jacobian requested: the value buffer is left empty
  if (!compute_jacobian)
    return;

  const auto nnz {g1_sparse_colptr[3 * ny + nx] - 1};
  jacobian_p.resize(nnz);
}

/* Loads the dynamic_resid and dynamic_g1 MEX files of model basename (looked
   up in the “+basename” directory, relative to the current directory) and
   resolves the four routines used by eval().

   On failure, the MEX call is aborted through mexErrMsgTxt(). Before doing
   so, any library that was successfully opened is closed again and the static
   handles and function pointers are reset to null, so that a failed load
   neither leaks a library handle nor leaves stale pointers behind. */
void
DynamicModelNoblockDllCaller::load_dll(const std::string& basename)
{
  // Closes a library handle if it is open, and resets it to null
  auto release = [](auto& handle) {
    if (!handle)
      return;
#if !defined(__CYGWIN32__) && !defined(_WIN32)
    dlclose(handle);
#else
    FreeLibrary(handle);
#endif
    handle = nullptr;
  };

  // Load symbols from dynamic MEX
  const std::filesystem::path model_dir {"+" + basename};
  const std::filesystem::path resid_mex_name {model_dir / ("dynamic_resid"s + MEXEXT)},
      g1_mex_name {model_dir / ("dynamic_g1"s + MEXEXT)};
#if !defined(__CYGWIN32__) && !defined(_WIN32)
  resid_mex = dlopen(resid_mex_name.c_str(), RTLD_NOW);
  g1_mex = dlopen(g1_mex_name.c_str(), RTLD_NOW);
#else
  resid_mex = LoadLibraryW(resid_mex_name.c_str());
  g1_mex = LoadLibraryW(g1_mex_name.c_str());
#endif
  if (!resid_mex || !g1_mex)
    {
      // Report the residual library first, as it is the first one loaded
      const bool resid_failed {!resid_mex};
      release(resid_mex);
      release(g1_mex);
      mexErrMsgTxt(resid_failed ? "Can't load dynamic_resid MEX file"
                                : "Can't load dynamic_g1 MEX file");
    }

#if !defined(__CYGWIN32__) && !defined(_WIN32)
  residual_tt_fct = reinterpret_cast<dynamic_tt_fct>(dlsym(resid_mex, "dynamic_resid_tt"));
  residual_fct = reinterpret_cast<dynamic_fct>(dlsym(resid_mex, "dynamic_resid"));
  g1_tt_fct = reinterpret_cast<dynamic_tt_fct>(dlsym(g1_mex, "dynamic_g1_tt"));
  g1_fct = reinterpret_cast<dynamic_fct>(dlsym(g1_mex, "dynamic_g1"));
#else
# pragma GCC diagnostic push
# pragma GCC diagnostic ignored "-Wcast-function-type"
  residual_tt_fct = reinterpret_cast<dynamic_tt_fct>(GetProcAddress(resid_mex, "dynamic_resid_tt"));
  residual_fct = reinterpret_cast<dynamic_fct>(GetProcAddress(resid_mex, "dynamic_resid"));
  g1_tt_fct = reinterpret_cast<dynamic_tt_fct>(GetProcAddress(g1_mex, "dynamic_g1_tt"));
  g1_fct = reinterpret_cast<dynamic_fct>(GetProcAddress(g1_mex, "dynamic_g1"));
# pragma GCC diagnostic pop
#endif
  if (!residual_tt_fct || !residual_fct || !g1_tt_fct || !g1_fct)
    {
      // Do not keep pointers into libraries that are about to be closed
      residual_tt_fct = nullptr;
      g1_tt_fct = nullptr;
      residual_fct = nullptr;
      g1_fct = nullptr;
      release(resid_mex);
      release(g1_mex);
      mexErrMsgTxt("Can't load functions in dynamic MEX file");
    }
}

void
DynamicModelNoblockDllCaller::unload_dll()
{
#if !defined(__CYGWIN32__) && !defined(_WIN32)
  dlclose(resid_mex);
  dlclose(g1_mex);
#else
  FreeLibrary(resid_mex);
  FreeLibrary(g1_mex);
#endif
}

/* Evaluates the model at the point currently stored in y_p and x_p.

   The residuals are written to resid. If compute_jacobian is set, the values
   of the Jacobian are additionally written to jacobian_p. For each of the two
   computations, the “_tt” routine is called first with the tt buffer, which
   is then passed on to the main routine. */
void
DynamicModelNoblockDllCaller::eval(double* resid)
{
  auto* const y = y_p.data();
  auto* const x = x_p.data();
  auto* const T = tt.data();

  residual_tt_fct(y, x, params, steady_state, T);
  residual_fct(y, x, params, steady_state, T, resid);

  if (!compute_jacobian)
    return;

  g1_tt_fct(y, x, params, steady_state, T);
  g1_fct(y, x, params, steady_state, T, jacobian_p.data());

  // For a linear model, the Jacobian computed once remains valid afterwards
  if (linear)
    compute_jacobian = false;
}

/* Constructs a caller for the M-file version of a model without block
   decomposition. All the work is done by the DynamicModelMatlabCaller
   constructor, to which every argument is forwarded unchanged (basename_arg
   being moved). */
DynamicModelNoblockMatlabCaller::DynamicModelNoblockMatlabCaller(
    std::string basename_arg, mwIndex ny, mwIndex nx, const mxArray* params_mx_arg,
    const mxArray* steady_state_mx_arg, const mxArray* g1_sparse_rowval_mx_arg,
    const mxArray* g1_sparse_colval_mx_arg, const mxArray* g1_sparse_colptr_mx_arg, bool linear_arg,
    bool compute_jacobian_arg) :
    DynamicModelMatlabCaller {std::move(basename_arg),
                              ny,
                              nx,
                              params_mx_arg,
                              steady_state_mx_arg,
                              g1_sparse_rowval_mx_arg,
                              g1_sparse_colval_mx_arg,
                              g1_sparse_colptr_mx_arg,
                              linear_arg,
                              compute_jacobian_arg}
{
}

void
DynamicModelNoblockMatlabCaller::eval(double* resid)
{
  mxArray *T_order_mx, *T_mx;

  {
    // Compute residuals
    std::string funcname {basename + ".dynamic_resid"};
    std::array<mxArray*, 3> plhs;
    std::array prhs {y_mx, x_mx, params_mx, steady_state_mx};

    mxArray* exception {mexCallMATLABWithTrap(plhs.size(), plhs.data(), prhs.size(), prhs.data(),
                                              funcname.c_str())};
    if (exception)
      {
        setMException(exception);
        return; // Avoid manipulating null pointers in plhs, see #1832
      }

    if (!mxIsDouble(plhs[0]) || mxIsSparse(plhs[0]))
      {
        setErrMsg("Residuals should be a dense array of double floats");
        return;
      }

    if (mxIsComplex(plhs[0]))
      plhs[0] = cmplxToReal<false>(plhs[0]);

    std::ranges::copy_n(mxGetDoubles(plhs[0]), mxGetNumberOfElements(plhs[0]), resid);
    mxDestroyArray(plhs[0]);

    T_order_mx = plhs[1];
    T_mx = plhs[2];
  }

  if (compute_jacobian)
    {
      // Compute Jacobian
      std::string funcname {basename + ".dynamic_g1"};
      std::array<mxArray*, 1> plhs;
      std::array prhs {y_mx,
                       x_mx,
                       params_mx,
                       steady_state_mx,
                       g1_sparse_rowval_mx,
                       g1_sparse_colval_mx,
                       g1_sparse_colptr_mx,
                       T_order_mx,
                       T_mx};

      mxArray* exception {mexCallMATLABWithTrap(plhs.size(), plhs.data(), prhs.size(), prhs.data(),
                                                funcname.c_str())};
      if (exception)
        {
          setMException(exception);
          return; // Avoid manipulating null pointers in plhs, see #1832
        }

      if (jacobian_mx)
        {
          mxDestroyArray(jacobian_mx);
          jacobian_mx = nullptr;
        }

      if (!mxIsDouble(plhs[0]) || !mxIsSparse(plhs[0]))
        {
          setErrMsg("Jacobian should be a dense array of double floats");
          return;
        }

      if (mxIsComplex(plhs[0]))
        jacobian_mx = cmplxToReal<true>(plhs[0]);
      else
        jacobian_mx = plhs[0];

      if (linear)
        compute_jacobian = false; // If model is linear, no need to recompute Jacobian later
    }

  mxDestroyArray(T_order_mx);
  mxDestroyArray(T_mx);
}

/* Definitions of the static members of DynamicModelBlockDllCaller.

   mex is the handle of the dynamically loaded MEX file of the block (a void*
   as returned by dlopen() on POSIX systems, a HINSTANCE as returned by
   LoadLibraryW() under Windows/Cygwin). The two function pointers are looked
   up in that library by load_dll(). Everything starts as a null pointer. */
#if !defined(_WIN32) && !defined(__CYGWIN32__)
void* DynamicModelBlockDllCaller::mex {nullptr};
#else
HINSTANCE DynamicModelBlockDllCaller::mex {nullptr};
#endif
DynamicModelBlockDllCaller::dynamic_resid_fct DynamicModelBlockDllCaller::residual_fct {nullptr};
DynamicModelBlockDllCaller::dynamic_g1_fct DynamicModelBlockDllCaller::g1_fct {nullptr};

void
DynamicModelBlockDllCaller::load_dll(const std::string& basename, int block_num)
{
  // Load symbols from dynamic MEX
  const std::filesystem::path model_dir {"+" + basename + "/+block"};
  const std::filesystem::path mex_name {model_dir
                                        / ("dynamic_"s + std::to_string(block_num) + MEXEXT)};
#if !defined(__CYGWIN32__) && !defined(_WIN32)
  mex = dlopen(mex_name.c_str(), RTLD_NOW);
#else
  mex = LoadLibraryW(mex_name.c_str());
#endif
  if (!mex)
    mexErrMsgTxt(("Can't load dynamic_"s + std::to_string(block_num) + " MEX file").c_str());

#if !defined(__CYGWIN32__) && !defined(_WIN32)
  residual_fct = reinterpret_cast<dynamic_resid_fct>(
      dlsym(mex, ("dynamic_"s + std::to_string(block_num) + "_resid").c_str()));
  g1_fct = reinterpret_cast<dynamic_g1_fct>(
      dlsym(mex, ("dynamic_"s + std::to_string(block_num) + "_g1").c_str()));
#else
# pragma GCC diagnostic push
# pragma GCC diagnostic ignored "-Wcast-function-type"
  residual_fct = reinterpret_cast<dynamic_resid_fct>(
      GetProcAddress(mex, ("dynamic_"s + std::to_string(block_num) + "_resid").c_str()));
  g1_fct = reinterpret_cast<dynamic_g1_fct>(
      GetProcAddress(mex, ("dynamic_"s + std::to_string(block_num) + "_g1").c_str()));
# pragma GCC diagnostic pop
#endif
  if (!residual_fct || !g1_fct)
    mexErrMsgTxt("Can't load functions in dynamic MEX file");
}

void
DynamicModelBlockDllCaller::unload_dll()
{
#if !defined(__CYGWIN32__) && !defined(_WIN32)
  dlclose(mex);
#else
  FreeLibrary(mex);
#endif
}

/* Constructs a caller for the DLL version of one block of a
   block-decomposed model.

   All arguments except mfs are forwarded to DynamicModelDllCaller. In
   addition, when the Jacobian has to be computed, the buffer receiving its
   values is sized from entry 3*mfs of the column pointer array (the column
   pointers being one-based). */
DynamicModelBlockDllCaller::DynamicModelBlockDllCaller(size_t ntt, size_t mfs, mwIndex ny,
                                                       mwIndex nx, const double* params_arg,
                                                       const double* steady_state_arg,
                                                       const int32_T* g1_sparse_colptr_arg,
                                                       bool linear_arg, bool compute_jacobian_arg) :
    DynamicModelDllCaller {ntt,
                           ny,
                           nx,
                           params_arg,
                           steady_state_arg,
                           g1_sparse_colptr_arg,
                           linear_arg,
                           compute_jacobian_arg}
{
  // No Jacobian requested: the value buffer is left empty
  if (!compute_jacobian)
    return;

  const auto nnz {g1_sparse_colptr[3 * mfs] - 1};
  jacobian_p.resize(nnz);
}

/* Evaluates the block at the point currently stored in y_p and x_p.

   The residuals are written to resid. If compute_jacobian is set, the values
   of the Jacobian are additionally written to jacobian_p. */
void
DynamicModelBlockDllCaller::eval(double* resid)
{
  auto* const y = y_p.data();
  auto* const x = x_p.data();
  auto* const T = tt.data();

  residual_fct(y, x, params, steady_state, T, resid);

  if (!compute_jacobian)
    return;

  g1_fct(y, x, params, steady_state, T, jacobian_p.data());

  // For a linear model, the Jacobian computed once remains valid afterwards
  if (linear)
    compute_jacobian = false;
}

/* Constructs a caller for the M-file version of one block of a
   block-decomposed model.

   block_num_arg is stored in block_num and later used by eval() to build the
   name of the function to call. T_mx is created as a real double column
   vector with ntt elements; it is released by the destructor. All the other
   arguments are forwarded to the DynamicModelMatlabCaller constructor. */
DynamicModelBlockMatlabCaller::DynamicModelBlockMatlabCaller(
    std::string basename_arg, int block_num_arg, mwIndex ntt, mwIndex ny, mwIndex nx,
    const mxArray* params_mx_arg, const mxArray* steady_state_mx_arg,
    const mxArray* g1_sparse_rowval_mx_arg, const mxArray* g1_sparse_colval_mx_arg,
    const mxArray* g1_sparse_colptr_mx_arg, bool linear_arg, bool compute_jacobian_arg) :
    DynamicModelMatlabCaller {std::move(basename_arg),
                              ny,
                              nx,
                              params_mx_arg,
                              steady_state_mx_arg,
                              g1_sparse_rowval_mx_arg,
                              g1_sparse_colval_mx_arg,
                              g1_sparse_colptr_mx_arg,
                              linear_arg,
                              compute_jacobian_arg},
    block_num {block_num_arg},
    T_mx {mxCreateDoubleMatrix(ntt, 1, mxREAL)}
{
}

/* Releases the T_mx array, which is the only mxArray directly held by this
   class (it is created by the constructor and possibly replaced by eval()). */
DynamicModelBlockMatlabCaller::~DynamicModelBlockMatlabCaller()
{
  mxDestroyArray(T_mx);
}

/* Evaluates the block by calling <basename>.block.dynamic_<block_num>.

   The function is called with 3 outputs, or 4 when compute_jacobian is set:
   the updated y, the updated T, the residuals and (optionally) the Jacobian.
   On success, y_mx and T_mx are replaced by the returned arrays, the
   residuals are copied to resid and, if requested, the returned Jacobian
   replaces the one stored in jacobian_mx.

   On failure (exception raised by the called function, or output of an
   unexpected type), the error is recorded through setMException() or
   setErrMsg() and the function returns early; resid may then not have been
   filled. The returned arrays that are not adopted are released. */
void
DynamicModelBlockMatlabCaller::eval(double* resid)
{
  std::string funcname {basename + ".block.dynamic_" + std::to_string(block_num)};

  std::array prhs {y_mx,
                   x_mx,
                   params_mx,
                   steady_state_mx,
                   g1_sparse_rowval_mx,
                   g1_sparse_colval_mx,
                   g1_sparse_colptr_mx,
                   T_mx};
  std::vector<mxArray*> plhs(3 + static_cast<int>(compute_jacobian));
  mxArray* exception {
      mexCallMATLABWithTrap(plhs.size(), plhs.data(), prhs.size(), prhs.data(), funcname.c_str())};

  if (exception)
    {
      setMException(exception);
      return; // Avoid manipulating null pointers in plhs, see #1832
    }

  if (mxGetDoubles(y_mx) != mxGetDoubles(plhs[0])) // Under Octave, plhs[0] and prhs[0] share the
                                                   // same data array if y is unchanged (see #1996)
    mxDestroyArray(y_mx);
  if (mxIsComplex(plhs[0]))
    y_mx = cmplxToReal<false>(plhs[0]);
  else
    y_mx = plhs[0];

  if (mxGetDoubles(T_mx) != mxGetDoubles(plhs[1])) // Under Octave, plhs[1] and prhs[7] share the
                                                   // same data array if T is unchanged (see #1996)
    mxDestroyArray(T_mx);
  if (mxIsComplex(plhs[1]))
    T_mx = cmplxToReal<false>(plhs[1]);
  else
    T_mx = plhs[1];

  if (!mxIsDouble(plhs[2]) || mxIsSparse(plhs[2]))
    {
      /* Neither the residuals nor the Jacobian (if it was requested) are
         going to be used: release them */
      for (size_t i {2}; i < plhs.size(); i++)
        mxDestroyArray(plhs[i]);
      setErrMsg("Residuals should be a dense array of double floats");
      return;
    }

  if (mxIsComplex(plhs[2]))
    plhs[2] = cmplxToReal<false>(plhs[2]);

  std::ranges::copy_n(mxGetDoubles(plhs[2]), mxGetNumberOfElements(plhs[2]), resid);
  mxDestroyArray(plhs[2]);

  if (compute_jacobian)
    {
      if (jacobian_mx)
        {
          mxDestroyArray(jacobian_mx);
          jacobian_mx = nullptr;
        }

      if (!mxIsDouble(plhs[3]) || !mxIsSparse(plhs[3]))
        {
          // The returned array is not adopted as the Jacobian: release it
          mxDestroyArray(plhs[3]);
          setErrMsg("Jacobian should be a sparse array of double floats");
          return;
        }

      if (mxIsComplex(plhs[3]))
        jacobian_mx = cmplxToReal<true>(plhs[3]);
      else
        jacobian_mx = plhs[3];

      if (linear)
        compute_jacobian = false; // If model is linear, no need to recompute Jacobian later
    }
}
