#' @title Partial Least Squares Regression Analysis
#' @name pls_analysis
#' @description Functions for multivariate PLS regression with cross-validation.
NULL

#' Fit PLS Regression with Cross-Validation Component Selection
#'
#' Fits a partial least squares regression model with automatic selection
#' of the optimal number of components via cross-validation.
#'
#' @param X_matrix Numeric matrix of predictor variables (direct prices).
#'   Non-matrix input is coerced with \code{as.matrix()}.
#' @param Y_matrix Numeric matrix of response variables (production prices).
#'   Non-matrix input is coerced with \code{as.matrix()}.
#' @param max_components Maximum number of components to consider.
#'   Default NULL uses min(ncol(X), nrow(X)-1, ncol(Y), 25). A supplied
#'   value must be a single number >= 1; it is truncated to an integer and
#'   capped at min(ncol(X), nrow(X)-1, ncol(Y)) (but never below 1).
#' @param cv_segments Number of cross-validation segments. Default 10.
#'   Capped at the number of usable observations.
#' @param scale Logical. Scale variables before fitting. Default TRUE.
#' @param center Logical. Center variables before fitting. Default TRUE.
#'
#' @return A list containing:
#' \describe{
#'   \item{model}{The fitted pls model object}
#'   \item{optimal_ncomp}{Optimal number of components by CV-RMSE}
#'   \item{cv_table}{Data frame with CV metrics by number of components}
#'   \item{metrics_cv}{List with \code{ncomp}, \code{RMSE_log},
#'     \code{MAE_log} and \code{R2_CV} at the optimal component number}
#'   \item{metrics_insample}{In-sample metrics at optimal component number}
#'   \item{n_obs}{Number of observations used after dropping unusable rows}
#'   \item{p_x}{Number of predictor columns}
#'   \item{p_y}{Number of response columns}
#' }
#'
#' @details
#' This function uses the pls package for PLS regression. Rows containing
#' a missing or non-finite value in either matrix are dropped with a
#' warning before fitting. Cross-validation uses consecutive segments.
#' Component selection is based on minimizing cross-validated RMSE; ties
#' are resolved in favour of the smaller number of components, and
#' component counts whose CV-RMSE is not available are ignored. Metrics are
#' reported on the scale of \code{Y_matrix} (labelled "log") and after
#' exponentiating (labelled "orig"), so \code{Y_matrix} is expected to hold
#' log-transformed values.
#'
#' @examples
#' \donttest{
#' if (requireNamespace("pls", quietly = TRUE)) {
#'   set.seed(123)
#'   n <- 50
#'   p <- 10
#'   X <- matrix(rnorm(n * p), n, p)
#'   colnames(X) <- paste0("X", 1:p)
#'   Y <- X[, 1:3] %*% diag(c(1, 0.5, 0.3)) + matrix(rnorm(n * 3, 0, 0.5), n, 3)
#'   colnames(Y) <- paste0("Y", 1:3)
#'
#'   result <- fit_pls_multivariate(X, Y, max_components = 8)
#'   print(result$optimal_ncomp)
#'   print(result$cv_table)
#' }
#' }
#'
#' @export
fit_pls_multivariate <- function(X_matrix,
                                  Y_matrix,
                                  max_components = NULL,
                                  cv_segments = 10L,
                                  scale = TRUE,
                                  center = TRUE) {

    check_package("pls", "partial least squares regression")

    if (!is.matrix(X_matrix)) {
        X_matrix <- as.matrix(X_matrix)
    }
    if (!is.matrix(Y_matrix)) {
        Y_matrix <- as.matrix(Y_matrix)
    }

    if (nrow(X_matrix) != nrow(Y_matrix)) {
        stop("X_matrix and Y_matrix must have the same number of rows.")
    }

    # Reject unusable component limits up front instead of letting them
    # surface as an opaque error inside the fitting routine.
    if (!is.null(max_components)) {
        limit_ok <- is.numeric(max_components) &&
            length(max_components) == 1L &&
            is.finite(max_components) &&
            max_components >= 1
        if (!limit_ok) {
            stop("max_components must be NULL or a single number >= 1.")
        }
    }

    # is.finite() is FALSE for NA, NaN and +/-Inf alike, so a single pass
    # identifies every row that cannot be used.
    complete_rows <- rowSums(!is.finite(cbind(X_matrix, Y_matrix))) == 0

    if (!all(complete_rows)) {
        n_removed <- sum(!complete_rows)
        warning(sprintf("Removed %d rows with missing or non-finite values.", n_removed))
        X_matrix <- X_matrix[complete_rows, , drop = FALSE]
        Y_matrix <- Y_matrix[complete_rows, , drop = FALSE]
    }

    n_obs <- nrow(X_matrix)
    p_x <- ncol(X_matrix)
    p_y <- ncol(Y_matrix)

    if (n_obs < 2L) {
        stop("At least two complete observations are required to fit a PLS model.")
    }

    max_feasible <- max(1L, min(p_x, n_obs - 1L, p_y))

    if (is.null(max_components)) {
        max_components <- min(max_feasible, 25L)
    } else {
        max_components <- as.integer(min(floor(max_components), max_feasible))
    }

    k_segments <- min(cv_segments, n_obs)

    pls_model <- pls::plsr(
        Y_matrix ~ X_matrix,
        ncomp = max_components,
        validation = "CV",
        segments = k_segments,
        segment.type = "consecutive",
        scale = scale,
        center = center
    )

    cv_predictions <- pls_model$validation$pred

    if (is.null(cv_predictions)) {
        stop("Cross-validation predictions not available.")
    }

    # Use the third dimension of the CV array rather than max_components:
    # it is the number of components predictions actually exist for.
    ncomp_fitted <- dim(cv_predictions)[3L]

    if (ncomp_fitted < 1L) {
        stop("Cross-validation produced no components; check the number of observations and segments.")
    }

    cv_table <- compute_pls_cv_metrics(Y_matrix, cv_predictions, ncomp_fitted)

    # which.min() skips NA/NaN and returns the first minimum, i.e. the
    # smallest model among ties. Comparing with `==` would propagate NA
    # into the index whenever any CV-RMSE is missing.
    best_row <- which.min(cv_table$CV_RMSE_log)

    if (length(best_row) == 0L) {
        stop("No finite cross-validated RMSE available to select the number of components.")
    }

    optimal_ncomp <- cv_table$components[best_row]

    pred_train_arr <- stats::predict(pls_model, ncomp = optimal_ncomp)
    # Rebuild an n_obs x p_y matrix explicitly so single-column or
    # single-row results are not collapsed to a vector.
    pred_train <- matrix(pred_train_arr[, , 1L, drop = TRUE], nrow = n_obs, ncol = p_y)

    metrics_insample <- compute_multivariate_metrics(Y_matrix, pred_train)

    metrics_cv <- list(
        ncomp = optimal_ncomp,
        RMSE_log = cv_table$CV_RMSE_log[best_row],
        MAE_log = cv_table$CV_MAE_log[best_row],
        R2_CV = cv_table$CV_R2[best_row]
    )

    list(
        model = pls_model,
        optimal_ncomp = optimal_ncomp,
        cv_table = cv_table,
        metrics_cv = metrics_cv,
        metrics_insample = metrics_insample,
        n_obs = n_obs,
        p_x = p_x,
        p_y = p_y
    )
}


#' Compute PLS Cross-Validation Metrics
#'
#' Internal helper that builds one row of cross-validation metrics per
#' number of components and stacks the rows into a data frame.
#'
#' @param Y_actual Actual Y matrix.
#' @param cv_predictions 3D array of CV predictions [obs, vars, components].
#' @param ncomp Number of components fitted.
#'
#' @return Data frame with columns \code{components}, \code{CV_R2},
#'   \code{CV_RMSE_log}, \code{CV_MAE_log}, \code{CV_RMSE_orig} and
#'   \code{CV_MAE_orig}, one row per component count.
#'
#' @keywords internal
compute_pls_cv_metrics <- function(Y_actual, cv_predictions, ncomp) {

    rows <- nrow(Y_actual)
    cols <- ncol(Y_actual)

    # Metrics for the model that uses the first k components.
    metrics_for_component <- function(k) {

        # Re-wrap the slice so it keeps the shape of Y_actual even when
        # subsetting dropped an extent of length one.
        fitted_k <- matrix(
            cv_predictions[, , k, drop = TRUE],
            nrow = rows,
            ncol = cols
        )

        r2_k <- compute_multivariate_r2(Y_actual, fitted_k)
        rmse_k <- compute_multivariate_rmse(Y_actual, fitted_k)
        mae_k <- compute_multivariate_mae(Y_actual, fitted_k)

        # Errors after exponentiating both sides, restricted to cells
        # where actual and predicted values are both finite.
        usable <- is.finite(Y_actual) & is.finite(fitted_k)
        err_orig <- exp(Y_actual[usable]) - exp(fitted_k[usable])

        data.frame(
            components = k,
            CV_R2 = r2_k,
            CV_RMSE_log = rmse_k,
            CV_MAE_log = mae_k,
            CV_RMSE_orig = sqrt(mean(err_orig^2)),
            CV_MAE_orig = mean(abs(err_orig)),
            stringsAsFactors = FALSE
        )
    }

    per_component <- lapply(seq_len(ncomp), metrics_for_component)

    do.call(rbind, per_component)
}


#' Compute Multivariate R-squared
#'
#' Pools the residual and total sums of squares over all response columns
#' and returns one minus their ratio. Within each column only rows where
#' both the actual and the predicted value are finite are used, and the
#' column mean is taken over those rows.
#'
#' @param Y_actual Actual Y matrix.
#' @param Y_predicted Predicted Y matrix.
#'
#' @return Numeric R-squared value, or \code{NA_real_} when the pooled
#'   total sum of squares is not positive.
#'
#' @keywords internal
compute_multivariate_r2 <- function(Y_actual, Y_predicted) {

    resid_ss <- 0
    total_ss <- 0

    for (col in seq_len(ncol(Y_actual))) {
        observed <- Y_actual[, col]
        fitted <- Y_predicted[, col]
        keep <- is.finite(observed) & is.finite(fitted)

        # A column without a single usable pair contributes nothing.
        if (!any(keep)) {
            next
        }

        observed <- observed[keep]
        fitted <- fitted[keep]

        resid_ss <- resid_ss + sum((observed - fitted)^2)
        total_ss <- total_ss + sum((observed - mean(observed))^2)
    }

    if (total_ss > 0) 1 - resid_ss / total_ss else NA_real_
}


#' Compute Multivariate RMSE
#'
#' Root mean squared error pooled over every cell in which both the
#' actual and the predicted value are finite.
#'
#' @param Y_actual Actual Y matrix.
#' @param Y_predicted Predicted Y matrix.
#'
#' @return Numeric RMSE value (\code{NaN} when no cell is usable).
#'
#' @keywords internal
compute_multivariate_rmse <- function(Y_actual, Y_predicted) {

    keep <- is.finite(Y_actual) & is.finite(Y_predicted)
    squared_err <- (Y_actual[keep] - Y_predicted[keep])^2

    sqrt(mean(squared_err))
}


#' Compute Multivariate MAE
#'
#' Mean absolute error pooled over every cell in which both the actual
#' and the predicted value are finite.
#'
#' @param Y_actual Actual Y matrix.
#' @param Y_predicted Predicted Y matrix.
#'
#' @return Numeric MAE value (\code{NaN} when no cell is usable).
#'
#' @keywords internal
compute_multivariate_mae <- function(Y_actual, Y_predicted) {

    keep <- is.finite(Y_actual) & is.finite(Y_predicted)
    abs_err <- abs(Y_actual[keep] - Y_predicted[keep])

    mean(abs_err)
}


#' Compute Multivariate Metrics
#'
#' Internal function to compute all metrics for multivariate predictions.
#' Errors are pooled over every cell in which both the actual and the
#' predicted value are finite. The "orig" metrics are computed after
#' applying \code{exp()} to both matrices.
#'
#' @param Y_actual Actual Y matrix.
#' @param Y_predicted Predicted Y matrix.
#'
#' @return List with elements \code{mae_log}, \code{rmse_log},
#'   \code{mae_orig}, \code{rmse_orig} and \code{mae_rel_range}. The last
#'   is \code{mae_log} as a percentage of the range of the finite actual
#'   values, or \code{NA_real_} when that range is not a positive finite
#'   number.
#'
#' @keywords internal
compute_multivariate_metrics <- function(Y_actual, Y_predicted) {

    mae_log <- compute_multivariate_mae(Y_actual, Y_predicted)
    rmse_log <- compute_multivariate_rmse(Y_actual, Y_predicted)

    valid <- is.finite(Y_actual) & is.finite(Y_predicted)
    err_orig <- exp(Y_actual[valid]) - exp(Y_predicted[valid])
    mae_orig <- mean(abs(err_orig))
    rmse_orig <- sqrt(mean(err_orig^2))

    # Only finite actual values define the range. An infinite entry would
    # otherwise give an infinite (or NaN) range and a meaningless or
    # failing relative error, although the errors above ignore such cells.
    finite_actual <- Y_actual[is.finite(Y_actual)]
    range_actual <- if (length(finite_actual) > 0L) {
        diff(range(finite_actual))
    } else {
        NA_real_
    }

    mae_rel_range <- if (is.finite(range_actual) && range_actual > 0) {
        mae_log / range_actual * 100
    } else {
        NA_real_
    }

    list(
        mae_log = mae_log,
        rmse_log = rmse_log,
        mae_orig = mae_orig,
        rmse_orig = rmse_orig,
        mae_rel_range = mae_rel_range
    )
}


#' Extract PLS Variable Importance
#'
#' Extracts variable importance scores from a fitted PLS model. The score
#' of a variable is the Euclidean norm of its loadings, as returned by
#' \code{pls::loadings()}, over the first \code{ncomp} components.
#'
#' @param pls_result Result from fit_pls_multivariate.
#' @param ncomp Number of components to use. Default uses optimal. Must be
#'   a single number between 1 and the number of components stored in the
#'   model; anything else raises an error.
#'
#' @return Data frame with variable names and importance scores, sorted by
#'   decreasing importance. \code{NULL} (with a warning) if the model has
#'   no loadings.
#'
#' @examples
#' \donttest{
#' if (requireNamespace("pls", quietly = TRUE)) {
#'   set.seed(123)
#'   n <- 50
#'   p <- 10
#'   X <- matrix(rnorm(n * p), n, p)
#'   colnames(X) <- paste0("X", 1:p)
#'   Y <- X[, 1:3] %*% diag(c(1, 0.5, 0.3)) + matrix(rnorm(n * 3, 0, 0.5), n, 3)
#'   colnames(Y) <- paste0("Y", 1:3)
#'
#'   result <- fit_pls_multivariate(X, Y, max_components = 5)
#'   importance <- extract_pls_importance(result)
#'   print(head(importance))
#' }
#' }
#'
#' @export
extract_pls_importance <- function(pls_result, ncomp = NULL) {

    check_package("pls", "PLS analysis")

    if (is.null(ncomp)) {
        ncomp <- pls_result$optimal_ncomp
    }

    model_loadings <- pls::loadings(pls_result$model)

    if (is.null(model_loadings)) {
        warning("Could not extract loadings from PLS model.")
        return(NULL)
    }

    # Validate before subsetting: an out-of-range value would otherwise end
    # in a "subscript out of bounds" error, and zero in an empty selection.
    n_available <- ncol(model_loadings)
    ncomp_ok <- is.numeric(ncomp) &&
        length(ncomp) == 1L &&
        !is.na(ncomp) &&
        ncomp >= 1 &&
        ncomp <= n_available
    if (!ncomp_ok) {
        stop(sprintf("ncomp must be a single number between 1 and %d.", n_available))
    }

    loadings_mat <- as.matrix(model_loadings[, seq_len(ncomp), drop = FALSE])

    # Row-wise Euclidean norm of the selected loadings.
    importance <- sqrt(rowSums(loadings_mat^2))

    result <- data.frame(
        variable = rownames(loadings_mat),
        importance = importance,
        stringsAsFactors = FALSE
    )

    result <- result[order(result$importance, decreasing = TRUE), ]
    rownames(result) <- NULL

    result
}


#' Predict from PLS Model
#'
#' Generate predictions from a fitted PLS model.
#'
#' @param pls_result Result from fit_pls_multivariate.
#' @param newdata Optional new data matrix for prediction, with the same
#'   predictor columns (in the same order) as the data used for fitting.
#'   Non-matrix input such as a data frame is coerced with
#'   \code{as.matrix()}. When NULL, predictions for the training data are
#'   returned.
#' @param ncomp Number of components to use. Default uses optimal.
#'
#' @return Matrix of predictions with one row per observation and one
#'   column per response variable.
#'
#' @examples
#' \donttest{
#' if (requireNamespace("pls", quietly = TRUE)) {
#'   set.seed(123)
#'   n <- 50
#'   p <- 10
#'   X <- matrix(rnorm(n * p), n, p)
#'   colnames(X) <- paste0("X", 1:p)
#'   Y <- X[, 1:3] %*% diag(c(1, 0.5, 0.3)) + matrix(rnorm(n * 3, 0, 0.5), n, 3)
#'   colnames(Y) <- paste0("Y", 1:3)
#'
#'   result <- fit_pls_multivariate(X, Y, max_components = 5)
#'   preds <- predict_pls(result)
#'   dim(preds)
#' }
#' }
#'
#' @export
predict_pls <- function(pls_result, newdata = NULL, ncomp = NULL) {

    if (is.null(ncomp)) {
        ncomp <- pls_result$optimal_ncomp
    }

    model <- pls_result$model

    if (is.null(newdata)) {
        pred_arr <- stats::predict(model, ncomp = ncomp)
    } else {
        # The model is fitted from the formula `Y_matrix ~ X_matrix`, so a
        # non-matrix `newdata` would be searched for a variable of that
        # name instead of being used as the predictors. Coercing to a
        # matrix, as the fitting function does for its inputs, makes the
        # supplied columns the predictors.
        if (!is.matrix(newdata)) {
            newdata <- as.matrix(newdata)
        }
        pred_arr <- stats::predict(model, newdata = newdata, ncomp = ncomp)
    }

    # Take the first slice of the third dimension and rebuild the matrix
    # explicitly so a single row or column is not collapsed to a vector.
    pred_mat <- matrix(
        pred_arr[, , 1L, drop = TRUE],
        nrow = dim(pred_arr)[1L],
        ncol = dim(pred_arr)[2L]
    )

    colnames(pred_mat) <- colnames(pred_arr)

    pred_mat
}
