## ----setup, include=FALSE-----------------------------------------------------
# Global chunk defaults: show the code, and use a 6 x 4 inch figure size
# unless a chunk overrides it.
knitr::opts_chunk$set(
    echo = TRUE,
    fig.width = 6,
    fig.height = 4
)

## ----ot-intuition, fig.height=5-----------------------------------------------
# Two simple 1D distributions (each sums to 1) with different numbers of bins
source_dist <- c(0.4, 0.1, 0.4, 0.1)
target_dist <- c(0.1, 0.3, 0.1, 0.3, 0.2)

# Side-by-side panels; par() returns the previous values of the settings it
# changes so they can be restored afterwards.
oldpar <- par(mfrow = c(1, 2), mar = c(4, 4, 3, 1))
barplot(source_dist, col = "steelblue", main = "Source distribution",
        names.arg = seq_along(source_dist), ylim = c(0, 0.5),
        xlab = "Bin", ylab = "Mass")
barplot(target_dist, col = "tomato", main = "Target distribution",
        names.arg = seq_along(target_dist), ylim = c(0, 0.5),
        xlab = "Bin", ylab = "Mass")
# Pass the saved list back unchanged; this restores both mfrow and mar.
par(oldpar)

## ----tensor-illustration, fig.height=4----------------------------------------
# One example of each tensor order, side by side. par() returns the previous
# values of the settings it changes so they can be restored at the end.
oldpar <- par(mfrow = c(1, 3), mar = c(2, 2, 3, 1))

# Vector (order 1)
barplot(c(3, 1, 4, 1, 5), col = "steelblue", main = "Order 1: Vector")

# Matrix (order 2): a 3 x 4 grid drawn as a grayscale image
mat <- matrix(c(1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6), nrow = 3)
image(mat, col = gray((0:255) / 255), axes = FALSE, main = "Order 2: Matrix")

# 3D tensor (order 3) of shape 3 x 4 x 2; only slice 1 (arr[, , 1]) is drawn.
# Slice 1 reuses the matrix above instead of repeating its literal.
arr <- array(0, dim = c(3, 4, 2))
arr[, , 1] <- mat
arr[, , 2] <- matrix(c(6, 5, 4, 3, 5, 4, 3, 2, 4, 3, 2, 1), nrow = 3)
image(arr[, , 1], col = gray((0:255) / 255), axes = FALSE,
      main = "Order 3: Tensor\n(slice 1)")

# Pass the saved list back unchanged; this restores both mfrow and mar.
par(oldpar)

## ----quickstart, message=FALSE------------------------------------------------
library("otTensor")
library("rTensor")

## ----create-tensors-----------------------------------------------------------
# Source: a 4 x 5 matrix whose (i, j) entry is i + j.
# outer() builds the whole grid in one vectorized call. The index vectors are
# doubles so the result is a double matrix, the same as filling a
# matrix(0, ...) element by element.
arrX <- outer(as.double(seq_len(4)), as.double(seq_len(5)), "+")

# Target: a 6 x 7 matrix with the same (i, j) -> i + j pattern
# (different size is OK)
arrY <- outer(as.double(seq_len(6)), as.double(seq_len(7)), "+")

# rTensor's as.tensor() wraps each plain matrix in a Tensor object;
# the two conversions are independent of each other.
Y <- as.tensor(arrY)
X <- as.tensor(arrX)

## ----set-f--------------------------------------------------------------------
# Modes passed to OTT() as `f`.
# NOTE(review): presumably mode 1 (rows) and mode 2 (columns), one transport
# plan each -- confirm against ?OTT.
f <- c(1.0, 2.0)

## ----run-ott------------------------------------------------------------------
# Run OTT between source X and target Y over the modes in f.
# NOTE(review): num.sample / num.iter presumably set the sample size and the
# iteration count of the solver -- confirm against ?OTT.
result <- OTT(
    X = X,
    Y = Y,
    f = f,
    num.sample = 500,
    num.iter = 100
)

## ----inspect-results----------------------------------------------------------
# Print the dimensions of the first two transport plans in result$Ts.
# local() keeps the helper variable out of the global environment.
local({
    plans <- result$Ts
    cat("Transport plan 1:", dim(plans[[1]]), "\n")
    cat("Transport plan 2:", dim(plans[[2]]), "\n")
})

## ----visualize-results, fig.height=5, fig.width=6-----------------------------
# Draw a matrix as a grayscale image in the orientation it prints:
# row 1 at the top, column 1 at the left.
#
# image() puts z[1, 1] in the bottom-left corner and runs the first index
# along the x axis, so the rows are reversed and the result is transposed
# before plotting.
#
# mat:  a matrix (anything as.matrix() accepts) to display.
# main: plot title.
# Called for its plotting side effect.
.show_matrix <- function(mat, main = "") {
    mat <- as.matrix(mat)
    # drop = FALSE keeps single-row input as a matrix, so the transpose
    # still has columns along x and rows along y.
    mat_rev <- t(mat[rev(seq_len(nrow(mat))), , drop = FALSE])
    image(mat_rev, col = gray((0:255) / 255),
          xaxt = "n", yaxt = "n",
          xlab = "", ylab = "", axes = FALSE, main = main)
}

# 2 x 2 panel filled row by row: the two input matrices on the top row and
# the two transport plans returned by OTT() on the bottom row. par() returns
# the previous values of the settings it changes.
oldpar <- par(mfrow = c(2, 2), mar = c(2, 2, 3, 1))
.show_matrix(arrX, main = "Source (X)")
.show_matrix(arrY, main = "Target (Y)")
.show_matrix(result$Ts[[1]], main = "Transport Plan 1\n(rows)")
.show_matrix(result$Ts[[2]], main = "Transport Plan 2\n(columns)")
# Pass the saved list back unchanged; this restores both mfrow and mar.
par(oldpar)

## ----sessionInfo, echo=FALSE--------------------------------------------------
# Record the R version and attached packages used to build this document.
print(utils::sessionInfo())

