/*
 * Copyright (C) 1996-2011 Daniel Waggoner and Tao Zha
 *
 * This is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * It is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * If you did not receive a copy of the GNU General Public License
 * with this software, see <http://www.gnu.org/licenses/>.
 */

#include <string>

#if defined(MATLAB_MEX_FILE)
# include <dynmex.h>
# include "mat.h"
#else
# include <octave/oct.h>
# include <octave/interpreter.h>
#endif

extern "C"
{
#include "VARio.h"

/*
  Writes A0, Aplus, Zeta and the Base Transition Matrix to either
  a Matlab or an Octave .mat file
 */
int Write_VAR_Parameters_Mat(TStateModel *model, char *out_tag)
{
  T_VAR_Parameters *p=(T_VAR_Parameters*)(model->theta);

  TMatrix* A0 = MakeA0_All(nullptr, p);
  TMatrix* Aplus = MakeAplus_All(nullptr, p);
  TMatrix* Zeta = MakeZeta_All(nullptr, p);
  TMatrix Q = model->sv->baseQ;


  // Create variables in a format suitable for writing to the MAT file

  int nrows, ncols;
  double *data;
#if defined(MATLAB_MEX_FILE)
  mwSize dims[3];
#endif

  /* A0 */
  nrows = RowM(A0[0]);
  ncols = ColM(A0[0]);
#if defined(MATLAB_MEX_FILE)
  dims[0] = nrows;
  dims[1] = ncols;
  dims[2] = p->nstates;
  mxArray* A0mat = mxCreateNumericArray(3, (mwSize *)dims, mxDOUBLE_CLASS, mxREAL);
  data = mxGetPr(A0mat);
#else
  NDArray A0mat {dim_vector{nrows, ncols, p->nstates}};
# if OCTAVE_MAJOR_VERSION >= 10
  data = A0mat.rwdata();
# else
  data = A0mat.fortran_vec();
# endif
#endif
  for (int s=0; s < p->nstates; s++)
    for (int i=0; i < nrows; i++)
      for (int j=0; j < ncols; j++)
        data[s*(nrows*ncols)+i+j*nrows] = (double)(ElementM(A0[s], i, j));

  /* Aplus */
  nrows = RowM(Aplus[0]);
  ncols = ColM(Aplus[0]);
#if defined(MATLAB_MEX_FILE)
  dims[0] = nrows;
  dims[1] = ncols;
  mxArray* Aplusmat = mxCreateNumericArray(3, (mwSize *)dims, mxDOUBLE_CLASS, mxREAL);
  data = mxGetPr(Aplusmat);
#else
  NDArray Aplusmat {dim_vector{nrows, ncols, p->nstates}};
# if OCTAVE_MAJOR_VERSION >= 10
  data = Aplusmat.rwdata();
# else
  data = Aplusmat.fortran_vec();
# endif
#endif
  for (int s=0; s < p->nstates; s++)
    for (int i=0; i < nrows; i++)
      for (int j=0; j < ncols; j++)
        data[s*(nrows*ncols)+i+j*nrows] = (double)(ElementM(Aplus[s], i, j));

  /* Zeta */
  nrows = RowM(Zeta[0]);
  ncols = ColM(Zeta[0]);
#if defined(MATLAB_MEX_FILE)
  dims[0] = nrows;
  dims[1] = ncols;
  mxArray* Zetamat = mxCreateNumericArray(3, (mwSize *)dims, mxDOUBLE_CLASS, mxREAL);
  data = mxGetPr(Zetamat);
#else
  NDArray Zetamat {dim_vector{nrows, ncols, p->nstates}};
# if OCTAVE_MAJOR_VERSION >= 10
  data = Zetamat.rwdata();
# else
  data = Zetamat.fortran_vec();
# endif
#endif
  for (int s=0; s < p->nstates; s++)
    for (int i=0; i < nrows; i++)
      for (int j=0; j < ncols; j++)
        data[s*(nrows*ncols)+i+j*nrows] = (double)(ElementM(Zeta[s], i, j));

  /* Q */
  nrows = RowM(Q);
  ncols = ColM(Q);
#if defined(MATLAB_MEX_FILE)
  dims[0] = nrows = RowM(Q);
  dims[1] = ncols = ColM(Q);
  dims[2] = 0;
  mxArray* Qmat = mxCreateNumericArray(2, (mwSize *)dims, mxDOUBLE_CLASS, mxREAL);
  data = mxGetPr(Qmat);
#else
  Matrix Qmat {nrows, ncols};
# if OCTAVE_MAJOR_VERSION >= 10
  data = Qmat.rwdata();
# else
  data = Qmat.fortran_vec();
# endif
#endif
  for (int i=0; i < nrows; i++)
    for (int j=0; j < ncols; j++)
      data[i+j*nrows] = (double)(ElementM(Q, i, j));


  // Write the file

  std::string fullFilename {out_tag};
  fullFilename += ".mat";

#if defined(MATLAB_MEX_FILE)
  MATFile* matFile = matOpen(fullFilename.c_str(), "wz");
  if (matFile == nullptr)
    mexErrMsgTxt(("Error opening mat file " + fullFilename + " in Write_VAR_Parameters_Mat.").c_str());
  int status;

  status = matPutVariable(matFile, "A0", A0mat);
  if (status != 0)
    mexErrMsgTxt("Error writing A0 to matfile in Write_VAR_Parameters_Mat.");

  status = matPutVariable(matFile, "Aplus", Aplusmat);
  if (status != 0)
      mexErrMsgTxt("Error writing Aplus to matfile in Write_VAR_Parameters_Mat.");

  status = matPutVariable(matFile, "Zeta", Zetamat);
  if (status != 0)
    mexErrMsgTxt("Error writing Zeta to matfile in Write_VAR_Parameters_Mat.");

  status = matPutVariable(matFile, "Q", Qmat);
  if (status != 0)
    mexErrMsgTxt("Error writing Q to matfile in Write_VAR_Parameters_Mat.");

  matClose(matFile);
#else
  auto* interp = octave::interpreter::the_interpreter();

  interp->assign("A0", A0mat);
  interp->assign("Aplus", Aplusmat);
  interp->assign("Zeta", Zetamat);
  interp->assign("Q", Qmat);

  auto& lss = interp->get_load_save_system();

  octave_value_list save_args;
  save_args(0) = "-mat";
  save_args(1) = fullFilename;
  save_args(2) = "A0";
  save_args(3) = "Aplus";
  save_args(4) = "Zeta";
  save_args(5) = "Q";

  lss.save(save_args);
#endif

  for (int s=0; s < p->nstates; s++)
    {
      FreeMatrix(A0[s]);
      FreeMatrix(Aplus[s]);
      FreeMatrix(Zeta[s]);
    }
  return 1;
}

} // extern "C"
