#include <vector>
#include <cmath>
#include <limits>
#include <algorithm>
#include <numeric>
#include <thread>
#include "NumericUtils.h"
#include "CppStats.h"
#include "CppDistances.h"
#include <RcppThread.h>

// [[Rcpp::depends(RcppThread)]]

/*
 * Classify a single prediction unit for the 2D False Nearest Neighbors test.
 *
 * The nearest library unit is searched using only the first E1 columns; the
 * candidate is then tested on column E2 - 1.
 *
 * Returns:
 *   -1  the unit cannot be evaluated. This happens when:
 *       - the index is out of range, or the row is shorter than E2;
 *       - checkOneDimVectorNotNanNum reports 0 for the row (presumably no non-NaN values);
 *       - no library neighbor was found;
 *       - the nearest E1-space distance is ~0 per doubleNearlyEqual.
 *    0  the nearest neighbor is a true neighbor.
 *    1  the nearest neighbor is a false neighbor.
 */
static int FNNClassifyPoint2D(const std::vector<std::vector<double>>& embedding,
                              const std::vector<size_t>& lib,
                              size_t pidx,
                              size_t E1,
                              size_t E2,
                              double Rtol,
                              double Atol,
                              bool L1norm) {
  const size_t N = embedding.size();
  if (pidx >= N || embedding[pidx].size() < E2 ||
      checkOneDimVectorNotNanNum(embedding[pidx]) == 0) {
    return -1;
  }

  // The query's first E1 coordinates do not depend on the library point: build them once.
  const std::vector<double> xi(embedding[pidx].begin(), embedding[pidx].begin() + E1);

  double min_dist = std::numeric_limits<double>::max();
  size_t nn_idx = N;  // sentinel: no neighbor found yet

  for (size_t lidx : lib) {
    if (lidx >= N || lidx == pidx || embedding[lidx].size() < E2 ||
        checkOneDimVectorNotNanNum(embedding[lidx]) == 0) {
      continue;
    }

    const std::vector<double> xj(embedding[lidx].begin(), embedding[lidx].begin() + E1);
    // Last argument is passed as true as in the rest of this file
    // (presumably NaN removal; confirm in CppDistances.h).
    const double dist = CppDistance(xi, xj, L1norm, true);

    // A NaN distance never compares less, so such candidates are ignored.
    if (dist < min_dist) {
      min_dist = dist;
      nn_idx = lidx;
    }
  }

  if (nn_idx == N || doubleNearlyEqual(min_dist, 0.0)) return -1;

  // Compare the newly added coordinate (column E2 - 1).
  const double diff = std::abs(embedding[pidx][E2 - 1] - embedding[nn_idx][E2 - 1]);
  const double ratio = diff / min_dist;
  return (ratio > Rtol || diff > Atol) ? 1 : 0;
}

/*
 * Compute the False Nearest Neighbors (FNN) ratio for spatial cross-sectional data.
 *
 * For each prediction unit, the nearest library unit is located in the space of the
 * first E1 columns of `embedding`. The pair is counted as a "false" neighbor if
 * either of these holds:
 *   |x_i[E2-1] - x_nn[E2-1]| / dist_E1 > Rtol
 *   |x_i[E2-1] - x_nn[E2-1]| > Atol
 * A high proportion of false neighbors suggests E1 is too small to unfold the system.
 *
 * Execution modes (the numeric result is the same in both):
 * - parallel_level = 0: prediction units are processed in parallel via
 *   RcppThread::parallelFor with `threads` workers.
 * - parallel_level != 0: prediction units are processed serially in the calling
 *   thread (useful when the caller already parallelizes across E1 values).
 *
 * Parameters:
 * - embedding: Matrix (vector of rows); each row holds one spatial unit's coordinates.
 *              Row 0 must have at least E2 columns. Rows shorter than E2 are skipped.
 * - lib: 0-based row indices of library (candidate neighbor) units.
 *        Out-of-range indices are skipped.
 * - pred: 0-based row indices of prediction units. Out-of-range indices are skipped.
 * - E1: Base embedding dimension used to find the nearest neighbor (1 <= E1 < E2).
 * - E2: Dimension whose last coordinate (column E2 - 1) tests the neighbor
 *       (usually E1 + 1).
 * - threads: Number of threads used when parallel_level = 0.
 * - parallel_level: 0 for per-pred parallelism (default), otherwise serial.
 * - Rtol: Relative threshold (default 10.0).
 * - Atol: Absolute threshold (default 2.0).
 * - L1norm: Use Manhattan (L1) distance instead of Euclidean (L2).
 *
 * Returns:
 * - The proportion of false nearest neighbors among evaluated prediction units (0–1).
 * - NaN if the inputs are invalid: empty embedding, E1 == 0, E1 >= E2, or
 *   row 0 shorter than E2.
 * - NaN if no prediction unit could be evaluated.
 */
double CppSingleFNN(const std::vector<std::vector<double>>& embedding,
                    const std::vector<size_t>& lib,
                    const std::vector<size_t>& pred,
                    size_t E1,
                    size_t E2,
                    size_t threads,
                    int parallel_level = 0,
                    double Rtol = 10.0,
                    double Atol = 2.0,
                    bool L1norm = false) {
  // E1 must be a non-empty proper prefix of the E2 coordinates; this also keeps
  // `E2 - 1` from underflowing and `begin() + E1` inside the row.
  if (embedding.empty() || E1 == 0 || E1 >= E2 || embedding[0].size() < E2) {
    return std::numeric_limits<double>::quiet_NaN();  // Invalid dimensions
  }

  // One slot per prediction unit, so parallel writers never share a slot:
  // -1 = skipped, 0 = true neighbor, 1 = false neighbor.
  std::vector<int> flags(pred.size(), -1);
  auto classify = [&](size_t i) {
    flags[i] = FNNClassifyPoint2D(embedding, lib, pred[i], E1, E2, Rtol, Atol, L1norm);
  };

  if (parallel_level == 0) {
    RcppThread::parallelFor(0, pred.size(), classify, threads);
  } else {
    for (size_t i = 0; i < pred.size(); ++i) classify(i);
  }

  size_t false_count = 0;
  size_t total = 0;
  for (int flag : flags) {
    if (flag < 0) continue;
    ++total;
    if (flag == 1) ++false_count;
  }

  return total > 0 ? static_cast<double>(false_count) / static_cast<double>(total)
                   : std::numeric_limits<double>::quiet_NaN();
}

/*
 * Classify a single prediction unit for the 3D False Nearest Neighbors test.
 *
 * `valid` has one entry per unit index that exists at every embedding level.
 * An entry is nonzero when at least one level holds a row for which
 * checkOneDimVectorNotNanNum reports > 0.
 *
 * Returns:
 *   -1  the unit cannot be evaluated. This happens when:
 *       - the unit is invalid or out of range;
 *       - no library neighbor has a non-NaN E1 distance;
 *       - the best mean E1 distance is ~0;
 *       - there is no non-NaN lag difference at level E2 - 1.
 *    0  the nearest neighbor is a true neighbor.
 *    1  the nearest neighbor is a false neighbor.
 */
static int FNNClassifyPoint3D(const std::vector<std::vector<std::vector<double>>>& embedding,
                              const std::vector<size_t>& lib,
                              const std::vector<char>& valid,
                              size_t iidx,
                              size_t E1,
                              size_t E2,
                              double Rtol,
                              double Atol,
                              bool L1norm) {
    const size_t N = valid.size();
    if (iidx >= N || !valid[iidx]) return -1;

    double best_mean = std::numeric_limits<double>::max();
    size_t best_j = N;  // sentinel: no neighbor found yet

    for (size_t jidx : lib) {
        if (jidx >= N || !valid[jidx] || jidx == iidx) continue;

        // Mean of the per-level distances over levels 0 .. E1-1, skipping NaN distances.
        double sum = 0.0;
        size_t num_e1_dist = 0;
        for (size_t e = 0; e < E1; ++e) {
            const double d = CppDistance(embedding[e][iidx], embedding[e][jidx], L1norm, true);
            if (!std::isnan(d)) {
                sum += d;
                ++num_e1_dist;
            }
        }
        if (num_e1_dist == 0) continue;

        const double mean_d = sum / static_cast<double>(num_e1_dist);
        if (mean_d < best_mean) {
            best_mean = mean_d;
            best_j = jidx;
        }
    }

    if (best_j == N || doubleNearlyEqual(best_mean, 0.0)) return -1;

    // Mean absolute lag difference at level E2 - 1, skipping NaN differences.
    // Only lags present in both rows are compared, so ragged rows cannot be over-read.
    const std::vector<double>& xi = embedding[E2 - 1][iidx];
    const std::vector<double>& xj = embedding[E2 - 1][best_j];
    const size_t n_lags = std::min(xi.size(), xj.size());

    double diff = 0.0;
    size_t num_e2_dist = 0;
    for (size_t lag = 0; lag < n_lags; ++lag) {
        const double d = std::abs(xi[lag] - xj[lag]);
        if (!std::isnan(d)) {
            diff += d;
            ++num_e2_dist;
        }
    }
    if (num_e2_dist == 0) return -1;

    const double mean_e2 = diff / static_cast<double>(num_e2_dist);
    const double ratio = mean_e2 / best_mean;
    return (ratio > Rtol || mean_e2 > Atol) ? 1 : 0;
}

/*
 * Compute the False Nearest Neighbors (FNN) ratio for 3D embeddings.
 *
 * Embedding structure:
 *   embedding[e][unit][lag]
 *     e    = embedding level
 *     unit = spatial index
 *     lag  = lagged coordinate
 *
 * Distance definitions (NaN terms are skipped in both means):
 *   Dist_E1 = mean over e = 0 .. E1-1 of CppDistance(embedding[e][pred], embedding[e][lib])
 *   Dist_E2 = mean over the lags shared by both units of
 *             |embedding[E2-1][pred][lag] - embedding[E2-1][lib][lag]|
 *
 * For each prediction unit, the library unit minimising Dist_E1 is its nearest neighbor.
 * That neighbor is flagged as false if:
 *       Dist_E2 / Dist_E1 > Rtol    OR    Dist_E2 > Atol
 *
 * A unit index is usable only if it exists at every level. It is considered valid
 * when at least one level's row for it has a non-NaN value
 * (per checkOneDimVectorNotNanNum). Invalid or out-of-range lib/pred indices are skipped.
 *
 * Execution modes (the numeric result is the same in both):
 *   parallel_level = 0  → prediction units processed in parallel (RcppThread, `threads` workers)
 *   parallel_level != 0 → prediction units processed serially in the calling thread
 *
 * Returns:
 *   The proportion of false nearest neighbors in [0,1].
 *   NaN if the inputs are invalid (empty embedding, E1 == 0, E1 >= E2, E2 > number of levels).
 *   NaN if no prediction unit could be evaluated.
 */
double CppSingleFNN(const std::vector<std::vector<std::vector<double>>>& embedding,
                    const std::vector<size_t>& lib,
                    const std::vector<size_t>& pred,
                    size_t E1,
                    size_t E2,
                    size_t threads,
                    int parallel_level = 0,
                    double Rtol = 10.0,
                    double Atol = 2.0,
                    bool L1norm = false){
    if (embedding.empty()) return std::numeric_limits<double>::quiet_NaN();

    const size_t maxE = embedding.size();
    if (E1 == 0 || E1 >= E2 || E2 > maxE) {
        return std::numeric_limits<double>::quiet_NaN();
    }

    // Use the smallest unit count across levels so any index < N is valid at every level.
    size_t N = embedding[0].size();
    for (const auto& level : embedding) N = std::min(N, level.size());
    if (N == 0) return std::numeric_limits<double>::quiet_NaN();

    std::vector<char> valid(N, 0);
    for (size_t i = 0; i < N; ++i) {
        for (size_t e = 0; e < maxE; ++e) {
            if (checkOneDimVectorNotNanNum(embedding[e][i]) > 0) {
                valid[i] = 1;
                break;
            }
        }
    }

    // One slot per prediction unit, so parallel writers never share a slot:
    // -1 = skipped, 0 = true neighbor, 1 = false neighbor.
    std::vector<int> flags(pred.size(), -1);
    auto classify = [&](size_t pi) {
        flags[pi] = FNNClassifyPoint3D(embedding, lib, valid, pred[pi], E1, E2, Rtol, Atol, L1norm);
    };

    if (parallel_level == 0) {
        RcppThread::parallelFor(0, pred.size(), classify, threads);
    } else {
        for (size_t pi = 0; pi < pred.size(); ++pi) classify(pi);
    }

    size_t total = 0, false_count = 0;
    for (int f : flags) {
        if (f < 0) continue;
        ++total;
        if (f == 1) ++false_count;
    }

    return total > 0 ? double(false_count) / double(total)
                     : std::numeric_limits<double>::quiet_NaN();
}

/*
 * Compute False Nearest Neighbor (FNN) ratios across multiple embedding dimensions
 * for spatial cross-sectional data.
 *
 * For a given embedding matrix, this function evaluates the proportion of false
 * nearest neighbors (FNN) as the embedding dimension increases. Each row of the
 * matrix is a spatial unit and each column is an embedding dimension.
 *
 * It calls `CppSingleFNN` for each embedding dimension pair (E1, E2), where E1
 * ranges from 1 to D - 1 (D = number of columns in row 0) and E2 = E1 + 1.
 * The FNN ratio measures how often a nearest neighbor in dimension E1 becomes
 * distant in dimension E2, suggesting that E1 is insufficient for reconstructing
 * the system.
 *
 * Threading:
 * - parallel_level == 0: the loop over E1 runs serially, and each CppSingleFNN call
 *   parallelizes over prediction units.
 * - parallel_level != 0: the loop over E1 runs in parallel, and each CppSingleFNN call
 *   runs serially, which avoids nested thread pools.
 *
 * Parameters:
 * - embedding: A vector of rows; each row is a spatial unit’s embedding.
 *              Needs at least 2 columns (checked on row 0).
 * - lib: A vector of indices indicating the library set (0-based).
 * - pred: A vector of indices indicating the prediction set (0-based).
 * - Rtol: Relative distance thresholds, one per E1 (Rtol[E1 - 1]).
 * - Atol: Absolute distance thresholds, one per E1 (Atol[E1 - 1]).
 * - L1norm: If true, use L1 (Manhattan) distance; otherwise, use L2 (Euclidean).
 * - threads: Requested thread count. The absolute value is used, clamped to
 *            std::thread::hardware_concurrency() when that is known (nonzero).
 * - parallel_level: Selects where parallelism is applied (see above).
 *
 * Returns:
 * - A vector of D - 1 FNN ratios, one per E1 from 1 to D - 1. The vector is empty
 *   when the embedding is empty or has fewer than 2 columns.
 * - An entry is NaN when it is not computable. This includes E1 values that have
 *   no corresponding entry in Rtol or Atol.
 */
std::vector<double> CppFNN(const std::vector<std::vector<double>>& embedding,
                           const std::vector<size_t>& lib,
                           const std::vector<size_t>& pred,
                           const std::vector<double>& Rtol,
                           const std::vector<double>& Atol,
                           bool L1norm = false,
                           int threads = 8,
                           int parallel_level = 0) {
  // Validate before touching embedding[0]. Without this guard, an empty embedding is UB,
  // and a zero-column row makes `max_E2 - 1` wrap around to SIZE_MAX.
  if (embedding.empty() || embedding[0].size() < 2) {
    return std::vector<double>();  // Not enough dimensions to compute FNN
  }

  const size_t max_E2 = embedding[0].size();
  std::vector<double> results(max_E2 - 1, std::numeric_limits<double>::quiet_NaN());

  // Configure threads. hardware_concurrency() returns 0 when the value is unknown;
  // in that case keep the requested count instead of clamping it to zero.
  size_t threads_sizet = static_cast<size_t>(std::abs(threads));
  const size_t hw_threads = static_cast<size_t>(std::thread::hardware_concurrency());
  if (hw_threads > 0) {
    threads_sizet = std::min(hw_threads, threads_sizet);
  }

  // Only E1 values that have both an Rtol and an Atol entry are computed. The rest stay
  // NaN rather than reading past the end of the threshold vectors.
  const size_t n_pairs = std::min({max_E2 - 1, Rtol.size(), Atol.size()});
  if (n_pairs == 0) return results;

  auto compute = [&](size_t E1) {
    const size_t E2 = E1 + 1;
    results[E1 - 1] = CppSingleFNN(embedding, lib, pred, E1, E2, threads_sizet,
                                   parallel_level, Rtol[E1 - 1], Atol[E1 - 1], L1norm);
  };

  if (parallel_level == 0) {
    for (size_t E1 = 1; E1 <= n_pairs; ++E1) compute(E1);
  } else {
    // Each E1 writes only its own slot of `results`, so no locking is needed.
    RcppThread::parallelFor(1, n_pairs + 1, compute, threads_sizet);
  }

  return results;
}

/*
 * Compute FNN ratios for 3D embeddings across all embedding dimensions E1 = 1 .. D-1.
 *
 * `embedding[e][unit][lag]` is indexed by embedding level first, so D is embedding.size().
 * For each E1, CppSingleFNN is called with E2 = E1 + 1, Rtol[E1 - 1] and Atol[E1 - 1].
 *
 * Threading:
 * - parallel_level == 0: the loop over E1 runs serially, and each CppSingleFNN call
 *   parallelizes over prediction units.
 * - parallel_level != 0: the loop over E1 runs in parallel, and each CppSingleFNN call
 *   runs serially.
 * - `threads` is used by absolute value and clamped to
 *   std::thread::hardware_concurrency() when that is known (nonzero).
 *
 * Returns:
 * - A std::vector<double> of size D-1. An entry is NaN when it is not computable,
 *   including E1 values with no corresponding entry in Rtol or Atol.
 * - When D < 2, a single-element vector containing NaN.
 */
std::vector<double> CppFNN(const std::vector<std::vector<std::vector<double>>>& embedding,
                           const std::vector<size_t>& lib,
                           const std::vector<size_t>& pred,
                           const std::vector<double>& Rtol,
                           const std::vector<double>& Atol,
                           bool L1norm = false,
                           int threads = 8,
                           int parallel_level = 0){
    const size_t D = embedding.size();
    if (D < 2) return std::vector<double>(1, std::numeric_limits<double>::quiet_NaN());

    std::vector<double> out(D - 1, std::numeric_limits<double>::quiet_NaN());

    // Configure threads. hardware_concurrency() returns 0 when the value is unknown;
    // in that case keep the requested count instead of clamping it to zero.
    size_t threads_sizet = static_cast<size_t>(std::abs(threads));
    const size_t hw_threads = static_cast<size_t>(std::thread::hardware_concurrency());
    if (hw_threads > 0) {
        threads_sizet = std::min(hw_threads, threads_sizet);
    }

    // Only E1 values that have both an Rtol and an Atol entry are computed. The rest stay
    // NaN rather than reading past the end of the threshold vectors.
    const size_t n_pairs = std::min({D - 1, Rtol.size(), Atol.size()});
    if (n_pairs == 0) return out;

    auto compute = [&](size_t E1) {
        out[E1 - 1] = CppSingleFNN(
            embedding, lib, pred,
            E1, E1 + 1,
            threads_sizet, parallel_level,
            Rtol[E1 - 1], Atol[E1 - 1],
            L1norm
        );
    };

    if (parallel_level == 0) {
        for (size_t E1 = 1; E1 <= n_pairs; ++E1) compute(E1);
    } else {
        // Each E1 writes only its own slot of `out`, so no locking is needed.
        RcppThread::parallelFor(1, n_pairs + 1, compute, threads_sizet);
    }

    return out;
}
