/*
  This file is part of CDO. CDO is a collection of Operators to
  manipulate and analyse Climate model Data.

  Copyright (C) 2003-2020 Uwe Schulzweida, <uwe.schulzweida AT mpimet.mpg.de>
  See COPYING file for copying and redistribution conditions.

  This program is free software; you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation; version 2 of the License.

  This program is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.
*/

/*
   This module contains the following operators:

      Mrotuv      mrotuv          Forward rotation for MPIOM data
*/

#include <cdi.h>

#include <mpim_grid.h>
#include "array.h"
#include "cdo_options.h"
#include "process_int.h"
#include "cdi_lockedIO.h"

/*
  Rotate a vector field from grid-relative (i/j) components to geographic (lon/lat) components.

    in      :: u_i(ix,iy),v_j(ix,iy)      vector components in i-j-direction
    in      :: lat(ix,iy),lon(ix,iy)      latitudes and longitudes; treated as radians
                                          (lat is passed to std::cos, lon differences are wrapped by 2*pi)
    out     :: u_lon(ix,iy),v_lat(ix,iy)  vector components in lon-lat direction

  All arrays hold ix*iy values and are addressed with IX2D(j, i, ix).

  For every cell the local i- and j-direction is estimated from the coordinate differences between the
  neighbouring cells (cyclic in i, one-sided at the first and last row). The input vector is projected
  onto these directions and afterwards rescaled to the length of the input vector, so only the
  direction changes. Cells with a degenerate direction vector or a vanishing result are set to zero.
  Finally the sign of the v component is inverted.

  With Options::cdoVerbose set, every 20th cell in i and j with a non-zero input vector is printed to
  stdout, and a warning goes to stderr when the two local direction vectors are clearly not orthogonal.
*/
void
rotate_uv(double *u_i, double *v_j, long ix, long iy, double *lon, double *lat, double *u_lon, double *v_lat)
{
  // kept at the historical precision so that results stay bit-identical
  const double pi = 3.14159265359;

  // specification whether change in sign is needed for the output arrays
  const bool change_sign_u = false;
  const bool change_sign_v = true;

  // No separate initialization pass: every element of u_lon/v_lat is assigned in the loop below.

  for (long j = 0; j < iy; j++)
    {
      // treatment of the first and the last grid-row: fall back to the row itself
      const long jm1 = (j > 0) ? j - 1 : j;
      const long jp1 = (j + 1 < iy) ? j + 1 : j;

      for (long i = 0; i < ix; i++)
        {
          // cyclic neighbours across the 0-meridian
          const long im1 = (i > 0) ? i - 1 : ix - 1;
          const long ip1 = (i + 1 < ix) ? i + 1 : 0;

          const auto idx = IX2D(j, i, ix);
          const auto u_in = u_i[idx];
          const auto v_in = v_j[idx];

          // difference in latitudes
          const auto dlat_i = lat[IX2D(j, ip1, ix)] - lat[IX2D(j, im1, ix)];
          const auto dlat_j = lat[IX2D(jp1, i, ix)] - lat[IX2D(jm1, i, ix)];

          // difference in longitudes, wrapped into [-pi, pi]
          auto dlon_i = lon[IX2D(j, ip1, ix)] - lon[IX2D(j, im1, ix)];
          if (dlon_i > pi) dlon_i -= 2 * pi;
          if (dlon_i < (-pi)) dlon_i += 2 * pi;
          auto dlon_j = lon[IX2D(jp1, i, ix)] - lon[IX2D(jm1, i, ix)];
          if (dlon_j > pi) dlon_j -= 2 * pi;
          if (dlon_j < (-pi)) dlon_j += 2 * pi;

          // scale the longitude differences by cos(lat) of the cell
          const auto lat_factor = std::cos(lat[idx]);
          dlon_i *= lat_factor;
          dlon_j *= lat_factor;

          // projection by scalar product
          auto u_rot = u_in * dlon_i + v_in * dlat_i;
          auto v_rot = u_in * dlon_j + v_in * dlat_j;

          // normalize by the lengths of the direction vectors
          const auto dist_i = std::sqrt(dlon_i * dlon_i + dlat_i * dlat_i);
          const auto dist_j = std::sqrt(dlon_j * dlon_j + dlat_j * dlat_j);

          if (dist_i > 0 && dist_j > 0)
            {
              u_rot /= dist_i;
              v_rot /= dist_j;
            }
          else
            {
              u_rot = 0.0;
              v_rot = 0.0;
            }

          // velocity vector lengths: rescale the rotated vector to the length of the input vector
          const auto absold = std::sqrt(u_in * u_in + v_in * v_in);
          const auto absnew = std::sqrt(u_rot * u_rot + v_rot * v_rot);

          u_rot *= absold;
          v_rot *= absold;

          if (absnew > 0)
            {
              u_rot /= absnew;
              v_rot /= absnew;
            }
          else
            {
              u_rot = 0.0;
              v_rot = 0.0;
            }

          // change sign
          if (change_sign_u) u_rot *= -1;
          if (change_sign_v) v_rot *= -1;

          u_lon[idx] = u_rot;
          v_lat[idx] = v_rot;

          if (Options::cdoVerbose && i % 20 == 0 && j % 20 == 0 && absold > 0)
            {
              // length of the final (rescaled) vector
              const auto absrot = std::sqrt(u_rot * u_rot + v_rot * v_rot);

              printf("(absold,absnew) %ld %ld %g %g %g %g %g %g\n", j + 1, i + 1, absold, absrot, u_in, v_in, u_rot, v_rot);

              // test orthogonality; the magnitude is checked so that obtuse angles (negative product) are reported too
              const auto dotprod = dlon_i * dlon_j + dlat_j * dlat_i;
              if (std::fabs(dotprod) > 0.1) fprintf(stderr, "orthogonal? %ld %ld %g\n", j + 1, i + 1, dotprod);
            }
        }
    }
}

/*
  Derive the coordinates of the u and v points from the coordinates of the scalar points.

    in      :: grid1x(nlon,nlat),grid1y(nlon,nlat)  longitudes and latitudes of the scalar points
    out     :: gridux(nlon,nlat),griduy(nlon,nlat)  mean of cell (j,i) and its neighbour (j,i+1), cyclic in i
    out     :: gridvx(nlon,nlat),gridvy(nlon,nlat)  mean of cell (j,i) and its neighbour (j+1,i);
                                                    the last row is paired with itself

  Longitudes get a special treatment: if one value of a pair is above 340 and the other one below 20,
  the plain mean is shifted by 180 (upwards if it is below 180, downwards otherwise).
  The thresholds presume longitudes in degrees; the caller in this file converts to degrees beforehand.
*/
void
p_to_uv_grid(long nlon, long nlat, double *grid1x, double *grid1y, double *gridux, double *griduy, double *gridvx, double *gridvy)
{
  // mean of two longitudes with correction for pairs on opposite sides of the 0/360 seam
  const auto seam_mean = [](double xa, double xb) {
    auto xmean = (xa + xb) * 0.5;
    const bool straddles_seam = (xa > 340 && xb < 20) || (xa < 20 && xb > 340);
    if (straddles_seam) xmean += (xmean < 180) ? 180 : -180;
    return xmean;
  };

  // interpolate scalar to u points
  for (long j = 0; j < nlat; j++)
    for (long i = 0; i < nlon; i++)
      {
        const long inext = (i + 1 < nlon) ? i + 1 : 0;  // wrap around in i
        const auto here = IX2D(j, i, nlon);
        const auto east = IX2D(j, inext, nlon);

        gridux[here] = seam_mean(grid1x[here], grid1x[east]);
        griduy[here] = (grid1y[here] + grid1y[east]) * 0.5;
      }

  // interpolate scalar to v points
  for (long j = 0; j < nlat; j++)
    for (long i = 0; i < nlon; i++)
      {
        const long jnext = (j + 1 < nlat) ? j + 1 : nlat - 1;  // no wrap around in j
        const auto here = IX2D(j, i, nlon);
        const auto north = IX2D(jnext, i, nlon);

        gridvx[here] = seam_mean(grid1x[here], grid1x[north]);
        gridvy[here] = (grid1y[here] + grid1y[north]) * 0.5;
      }
}

/*
  Operator mrotuv: forward rotation for MPIOM data.

  Reads U and V from input stream 0, rotates them level by level with rotate_uv(), interpolates the
  rotated components to the u and v points and writes U to output stream 1 and V to output stream 2.

  U is the variable with code 3 or 131, V the variable with code 4 or 132. If they can not be
  identified by code and the input has exactly two variables, the first one is taken as U and the
  second one as V. Both variables must share the same grid (lonlat, gaussian or curvilinear) and the
  same number of levels.

  Missing values in the input are replaced by 0 before the rotation.
*/
void *
Mrotuv(void *process)
{
  int nrecs;
  int levelID;
  int varID, varid;
  int uid = -1, vid = -1;

  cdoInitialize(process);

  operatorCheckArgc(0);

  const auto streamID1 = cdoOpenRead(0);

  const auto vlistID1 = cdoStreamInqVlist(streamID1);

  // identify U and V by their code numbers
  const auto nvars = vlistNvars(vlistID1);
  for (varid = 0; varid < nvars; varid++)
    {
      int code = vlistInqVarCode(vlistID1, varid);
      if (code == 3 || code == 131) uid = varid;
      if (code == 4 || code == 132) vid = varid;
    }

  if (uid == -1 || vid == -1)
    {
      // fallback: with exactly two variables take them as U and V in this order
      if (nvars == 2)
        {
          uid = 0;
          vid = 1;
        }
      else
        cdoAbort("U and V not found in %s", cdoGetStreamName(0));
    }

  const auto nlevs = zaxisInqSize(vlistInqVarZaxis(vlistID1, uid));
  if (nlevs != zaxisInqSize(vlistInqVarZaxis(vlistID1, vid))) cdoAbort("U and V have different number of levels!");

  auto gridID1 = vlistInqVarGrid(vlistID1, uid);
  const auto gridID2 = vlistInqVarGrid(vlistID1, vid);
  const auto gridsize = gridInqSize(gridID1);
  if (gridID1 != gridID2) cdoAbort("Input grids differ!");

  if (gridInqType(gridID1) != GRID_LONLAT && gridInqType(gridID1) != GRID_GAUSSIAN && gridInqType(gridID1) != GRID_CURVILINEAR)
    cdoAbort("Grid %s unsupported!", gridNamePtr(gridInqType(gridID1)));

  // the coordinates are needed for every grid point
  if (gridInqType(gridID1) != GRID_CURVILINEAR) gridID1 = gridToCurvilinear(gridID1, 0);

  if (gridsize != gridInqSize(gridID1)) cdoAbort("Internal problem: gridsize changed!");

  const auto nlon = gridInqXsize(gridID1);
  const auto nlat = gridInqYsize(gridID1);

  Varray<double> grid1x(gridsize), grid1y(gridsize);
  Varray<double> gridux(gridsize), griduy(gridsize);
  Varray<double> gridvx(gridsize), gridvy(gridsize);

  // size of the help fields: one additional column on each side of every row
  const auto gridsizex = (nlon + 2) * nlat;

  gridInqXvals(gridID1, grid1x.data());
  gridInqYvals(gridID1, grid1y.data());

  // Convert lat/lon units if required
  cdo_grid_to_degree(gridID1, CDI_XAXIS, gridsize, grid1x.data(), "grid center lon");
  cdo_grid_to_degree(gridID1, CDI_YAXIS, gridsize, grid1y.data(), "grid center lat");

  // coordinates of the u and v points (computed in degrees)
  p_to_uv_grid(nlon, nlat, grid1x.data(), grid1y.data(), gridux.data(), griduy.data(), gridvx.data(), gridvy.data());

  const auto gridIDu = gridCreate(GRID_CURVILINEAR, nlon * nlat);
  gridDefDatatype(gridIDu, gridInqDatatype(gridID1));
  gridDefXsize(gridIDu, nlon);
  gridDefYsize(gridIDu, nlat);
  gridDefXvals(gridIDu, gridux.data());
  gridDefYvals(gridIDu, griduy.data());

  const auto gridIDv = gridCreate(GRID_CURVILINEAR, nlon * nlat);
  gridDefDatatype(gridIDv, gridInqDatatype(gridID1));
  gridDefXsize(gridIDv, nlon);
  gridDefYsize(gridIDv, nlat);
  gridDefXvals(gridIDv, gridvx.data());
  gridDefYvals(gridIDv, gridvy.data());

  // rotate_uv() works with radians
  for (size_t i = 0; i < gridsize; i++)
    {
      grid1x[i] *= DEG2RAD;
      grid1y[i] *= DEG2RAD;
    }

  // output vlist with U only, defined on the u points
  vlistClearFlag(vlistID1);
  for (int lid = 0; lid < nlevs; lid++) vlistDefFlag(vlistID1, uid, lid, true);
  const auto vlistID2 = vlistCreate();
  cdoVlistCopyFlag(vlistID2, vlistID1);
  vlistChangeVarGrid(vlistID2, 0, gridIDu);

  // output vlist with V only, defined on the v points
  vlistClearFlag(vlistID1);
  for (int lid = 0; lid < nlevs; lid++) vlistDefFlag(vlistID1, vid, lid, true);
  const auto vlistID3 = vlistCreate();
  cdoVlistCopyFlag(vlistID3, vlistID1);
  vlistChangeVarGrid(vlistID3, 0, gridIDv);

  const auto taxisID1 = vlistInqTaxis(vlistID1);
  const auto taxisID2 = taxisDuplicate(taxisID1);
  const auto taxisID3 = taxisDuplicate(taxisID1);
  vlistDefTaxis(vlistID2, taxisID2);
  vlistDefTaxis(vlistID3, taxisID3);

  const auto streamID2 = cdoOpenWrite(1);
  const auto streamID3 = cdoOpenWrite(2);

  cdoDefVlist(streamID2, vlistID2);
  cdoDefVlist(streamID3, vlistID3);

  const auto missval1 = vlistInqVarMissval(vlistID1, uid);
  const auto missval2 = vlistInqVarMissval(vlistID1, vid);

  Varray<double> ufield(gridsize), vfield(gridsize);

  Varray2D<double> urfield(nlevs), vrfield(nlevs);
  for (int lid = 0; lid < nlevs; lid++)
    {
      urfield[lid].resize(gridsize);
      vrfield[lid].resize(gridsize);
    }

  // Number of missing values per level. cdoReadRecord() reports the count of the record just read,
  // so a single counter per variable would only remember the level that was read last.
  Varray<size_t> nmiss1(nlevs), nmiss2(nlevs);

  Varray<double> uhelp(gridsizex), vhelp(gridsizex);

  int tsID = 0;
  while ((nrecs = cdoStreamInqTimestep(streamID1, tsID)))
    {
      taxisCopyTimestep(taxisID2, taxisID1);
      cdoDefTimestep(streamID2, tsID);
      taxisCopyTimestep(taxisID3, taxisID1);
      cdoDefTimestep(streamID3, tsID);

      // collect all levels of U and V of this timestep; records of other variables are skipped
      for (int recID = 0; recID < nrecs; recID++)
        {
          cdoInqRecord(streamID1, &varID, &levelID);

          if (varID == uid) cdoReadRecord(streamID1, urfield[levelID].data(), &nmiss1[levelID]);
          if (varID == vid) cdoReadRecord(streamID1, vrfield[levelID].data(), &nmiss2[levelID]);
        }

      for (levelID = 0; levelID < nlevs; levelID++)
        {
          // remove missing values
          if (nmiss1[levelID] || nmiss2[levelID])
            {
              for (size_t i = 0; i < gridsize; i++)
                {
                  if (DBL_IS_EQUAL(urfield[levelID][i], missval1)) urfield[levelID][i] = 0;
                  if (DBL_IS_EQUAL(vrfield[levelID][i], missval2)) vrfield[levelID][i] = 0;
                }
            }

          // rotate
          rotate_uv(urfield[levelID].data(), vrfield[levelID].data(), nlon, nlat, grid1x.data(), grid1y.data(), ufield.data(),
                    vfield.data());

          // load to a help field; column 0 and column nlon+1 of every row are left for the cyclic copies
          for (size_t j = 0; j < nlat; j++)
            for (size_t i = 0; i < nlon; i++)
              {
                uhelp[IX2D(j, i + 1, nlon + 2)] = ufield[IX2D(j, i, nlon)];
                vhelp[IX2D(j, i + 1, nlon + 2)] = vfield[IX2D(j, i, nlon)];
              }

          // make help field cyclic
          for (size_t j = 0; j < nlat; j++)
            {
              uhelp[IX2D(j, 0, nlon + 2)] = uhelp[IX2D(j, nlon, nlon + 2)];
              uhelp[IX2D(j, nlon + 1, nlon + 2)] = uhelp[IX2D(j, 1, nlon + 2)];
              vhelp[IX2D(j, 0, nlon + 2)] = vhelp[IX2D(j, nlon, nlon + 2)];
              vhelp[IX2D(j, nlon + 1, nlon + 2)] = vhelp[IX2D(j, 1, nlon + 2)];
            }

          // interpolate on u/v points
          // u: mean of the cell and its cyclic neighbour in i
          for (size_t j = 0; j < nlat; j++)
            for (size_t i = 0; i < nlon; i++)
              {
                ufield[IX2D(j, i, nlon)] = (uhelp[IX2D(j, i + 1, nlon + 2)] + uhelp[IX2D(j, i + 2, nlon + 2)]) * 0.5;
              }

          // v: mean of the cell and its neighbour in the next row
          for (size_t j = 0; j < nlat - 1; j++)
            for (size_t i = 0; i < nlon; i++)
              {
                vfield[IX2D(j, i, nlon)] = (vhelp[IX2D(j, i + 1, nlon + 2)] + vhelp[IX2D(j + 1, i + 1, nlon + 2)]) * 0.5;
              }

          // the last row has no next row and is copied
          for (size_t i = 0; i < nlon; i++)
            {
              vfield[IX2D(nlat - 1, i, nlon)] = vhelp[IX2D(nlat - 1, i + 1, nlon + 2)];
            }

          // NOTE(review): the counts of the input records are passed on, although the missing values were
          // replaced by 0 above - confirm whether 0 should be written instead.
          cdoDefRecord(streamID2, 0, levelID);
          cdoWriteRecord(streamID2, ufield.data(), nmiss1[levelID]);
          cdoDefRecord(streamID3, 0, levelID);
          cdoWriteRecord(streamID3, vfield.data(), nmiss2[levelID]);
        }

      tsID++;
    }

  cdoStreamClose(streamID3);
  cdoStreamClose(streamID2);
  cdoStreamClose(streamID1);

  cdoFinish();

  return nullptr;
}
