Iris Species Classification: A Comprehensive R Guide for Beginners

iris_classification_r_guide.livemd

Mix.install([
  {:rx, "~> 0.1"},
  {:kino, "~> 0.19.0"},
  {:plotly_ex, "~> 0.1.0"}
])

Source and License

This notebook is an Elixir Livebook port of “A Comprehensive R Guide for Beginners” by Adil Shamim, published on Kaggle:

https://www.kaggle.com/code/adilshamim8/a-comprehensive-r-guide-for-beginners

The original work is released under the Apache License, Version 2.0.

> Copyright 2024 Adil Shamim
>
> Licensed under the Apache License, Version 2.0 (the “License”); you may not use this file except in compliance with the License. You may obtain a copy of the License at
>
> http://www.apache.org/licenses/LICENSE-2.0
>
> Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

This port adapts the original R code to run through the Rx library from Elixir, restructures cells to fit the Livebook format, and adds Elixir-side coordination. The statistical analysis, R code, and narrative are derived from the original work.

Introduction

Welcome to this comprehensive guide on using R for machine learning with the classic Iris dataset, run from Elixir via the Rx library. Rx drives a persistent Rscript process — the BEAM coordinates the analysis while all statistical computation happens in R.

The Iris dataset contains measurements for 150 iris flowers from three species:

  • Iris setosa
  • Iris versicolor
  • Iris virginica

The goal is to develop models that predict species from sepal and petal measurements, and to compare their performance.

Table of Contents

  1. Setup and Data Loading
  2. Data Overview and Initial Exploration
  3. Exploratory Data Analysis
  4. Data Preprocessing
  5. Model Building and Evaluation
  6. Model Comparison
  7. Feature Importance
  8. Advanced: Feature Space Visualization
  9. Conclusion

Required R Packages

For repeatable package setup, prefer an renv.lock committed with the notebook project and start the process backend with Rx.renv_init/2. This notebook has a large, package-heavy CRAN footprint, so restore: true should be an explicit choice in a controlled project/cache location.

# :ok = Rx.renv_init("path/to/project", restore: true)

Backend Setup

This notebook can use the external Rscript process backend or the experimental native backend selected in the setup cell. All R globals persist across cells within a session. Run cells in order from top to bottom for correct results.

defmodule IrisGuide do
  @moduledoc false

  def init(backend \\ :process) do
    case backend do
      :process ->
        :ok = Rx.use_backend(:process, r_binary: "Rscript", lib_paths: [])

      :native ->
        r_home = System.cmd("R", ["RHOME"]) |> elem(0) |> String.trim()
        lib_r_path =
          [Path.join([r_home, "lib", "libR.so"]), Path.join([r_home, "lib", "libR.dylib"])]
          |> Enum.find(&File.exists?/1)
        :ok = Rx.use_backend(:native, r_home: r_home, lib_r_path: lib_r_path, lib_paths: [])
    end
  end

  # Evaluate R source and print captured output. Returns {result, globals}.
  def r(source, globals \\ %{}) do
    captured = Rx.eval(source, globals, capture: true)
    unless captured.stdout == "", do: IO.puts(String.trim_trailing(captured.stdout))

    unless captured.messages == "" do
      IO.puts("[R messages]")
      IO.puts(String.trim_trailing(captured.messages))
    end

    unless captured.warnings == "" do
      IO.puts("[R warnings]")
      IO.puts(String.trim_trailing(captured.warnings))
    end

    {captured.result, captured.globals}
  end

  # Evaluate R source and render captured plot output via Kino.
  def plot(source, globals \\ %{}, opts \\ []) do
    opts = Keyword.merge([width: 800, height: 520, res: 96], opts)
    Rx.Kino.plot(source, globals, opts)
  end
end
:ok = IrisGuide.init(:native)
#:ok = IrisGuide.init()
Rx.backend()

1. Setup and Data Loading

Load required packages, define a consistent my_theme for all visualizations, set the random seed, and prepare the Iris dataset with tidy column names.

{_result, r_globals} = IrisGuide.r(~S"""
suppressMessages(suppressWarnings({
  library(tidyverse)
  library(caret)
  library(rpart)
  library(rpart.plot)
  library(randomForest)
  library(e1071)
  library(class)
  library(corrplot)
  library(pROC)
  library(gridExtra)
  library(viridis)
}))

my_theme <- theme_minimal() +
  theme(
    plot.title    = element_text(face = "bold", size = 14, hjust = 0.5),
    plot.subtitle = element_text(hjust = 0.5, size = 12, color = "#666666"),
    axis.title    = element_text(face = "bold", size = 11),
    axis.text     = element_text(size = 10),
    legend.title  = element_text(face = "bold"),
    legend.position = "right",
    panel.grid.minor = element_blank(),
    panel.grid.major = element_line(color = "#EEEEEE"),
    panel.border  = element_rect(color = "#DDDDDD", fill = NA)
  )

set.seed(8888)

iris_df <- datasets::iris %>%
  rename(
    sepal_length = Sepal.Length,
    sepal_width  = Sepal.Width,
    petal_length = Petal.Length,
    petal_width  = Petal.Width,
    species      = Species
  )

cat("Dataset: Iris Species Classification\n")
cat("Dimensions:", dim(iris_df)[1], "observations,", dim(iris_df)[2], "variables\n\n")
cat("Species distribution:\n")
print(table(iris_df$species))
cat("\nFirst 6 rows:\n")
print(head(iris_df))
""", %{})
:ok

R ships the Iris dataset in the datasets package. It contains 150 observations with four measurements per flower (sepal and petal dimensions in centimetres) and a species label.

2. Data Overview and Initial Exploration

Structure

IrisGuide.r(~S"""
str(iris_df)
""", r_globals)
:ok

Summary Statistics

IrisGuide.r(~S"""
print(summary(iris_df))
""", r_globals)
:ok

Missing Values

IrisGuide.r(~S"""
cat("Missing values per column:\n")
print(colSums(is.na(iris_df)))
""", r_globals)
:ok

No missing values — the dataset is clean and ready for analysis.

Class Distribution

{_result, r_globals} = IrisGuide.r(~S"""
species_counts <- iris_df %>%
  count(species) %>%
  mutate(percentage = n / sum(n) * 100)

cat("Class distribution:\n")
print(species_counts)
""", r_globals)
:ok
IrisGuide.plot(~S"""
species_counts <- iris_df %>%
  count(species) %>%
  mutate(percentage = n / sum(n) * 100)

ggplot(species_counts, aes(x = species, y = n, fill = species)) +
  geom_col() +
  geom_text(aes(label = paste0(n, " (", round(percentage, 1), "%)")),
            vjust = -0.5, size = 4) +
  labs(title = "Class Distribution in Iris Dataset", x = "Species", y = "Count") +
  scale_fill_viridis_d(option = "plasma", begin = 0.2, end = 0.8) +
  my_theme +
  theme(legend.position = "none") +
  ylim(0, max(species_counts$n) * 1.15)
""", r_globals, width: 560, height: 400)

The dataset is perfectly balanced — 50 observations per species (33.3% each).

3. Exploratory Data Analysis

3.1 Univariate Analysis — Feature Distributions

IrisGuide.plot(~S"""
create_histogram <- function(data, variable, title, fill_color, binwidth = NULL) {
  if (is.null(binwidth)) {
    binwidth <- (max(data[[variable]]) - min(data[[variable]])) / 15
  }
  ggplot(data, aes(x = .data[[variable]])) +
    geom_histogram(aes(y = after_stat(density)), binwidth = binwidth,
                   fill = fill_color, color = "white", alpha = 0.7) +
    geom_density(alpha = 0.2, fill = "gray") +
    labs(title = title, x = gsub("_", " ", toupper(variable)), y = "Density") +
    my_theme
}

p1 <- create_histogram(iris_df, "sepal_length", "Sepal Length", "#4E79A7")
p2 <- create_histogram(iris_df, "sepal_width",  "Sepal Width",  "#F28E2B")
p3 <- create_histogram(iris_df, "petal_length", "Petal Length", "#E15759")
p4 <- create_histogram(iris_df, "petal_width",  "Petal Width",  "#76B7B2")

grid.arrange(p1, p2, p3, p4, ncol = 2,
             top = grid::textGrob("Distributions of Iris Features",
                                  gp = grid::gpar(fontsize = 16, fontface = "bold")))
""", r_globals, height: 620)

Petal length and petal width show bimodal distributions, suggesting the setosa species is clearly separated from the other two.

3.2 Feature Distributions by Species

IrisGuide.plot(~S"""
create_violin_plot <- function(data, variable, title) {
  ggplot(data, aes(x = species, y = .data[[variable]], fill = species)) +
    geom_violin(alpha = 0.7) +
    geom_boxplot(width = 0.1, alpha = 0.5, outlier.shape = NA) +
    geom_jitter(width = 0.1, alpha = 0.5, size = 1.5) +
    labs(title = title, x = "Species", y = gsub("_", " ", toupper(variable))) +
    scale_fill_viridis_d(option = "plasma", begin = 0.2, end = 0.8) +
    my_theme +
    theme(legend.position = "none")
}

v1 <- create_violin_plot(iris_df, "sepal_length", "Sepal Length by Species")
v2 <- create_violin_plot(iris_df, "sepal_width",  "Sepal Width by Species")
v3 <- create_violin_plot(iris_df, "petal_length", "Petal Length by Species")
v4 <- create_violin_plot(iris_df, "petal_width",  "Petal Width by Species")

grid.arrange(v1, v2, v3, v4, ncol = 2,
             top = grid::textGrob("Feature Distributions by Species",
                                  gp = grid::gpar(fontsize = 16, fontface = "bold")))
""", r_globals, height: 660)

Setosa is clearly distinct in petal dimensions. Versicolor and virginica overlap more, especially in sepal width.

3.3 Scatter Plots with Confidence Ellipses

IrisGuide.plot(~S"""
ggplot(iris_df, aes(x = sepal_length, y = sepal_width, color = species)) +
  geom_point(size = 3, alpha = 0.7) +
  stat_ellipse(aes(fill = species), geom = "polygon", alpha = 0.1, level = 0.95) +
  labs(title = "Sepal Width vs. Sepal Length",
       x = "Sepal Length (cm)", y = "Sepal Width (cm)",
       color = "Species", fill = "Species") +
  scale_color_viridis_d(option = "plasma", begin = 0.2, end = 0.8) +
  scale_fill_viridis_d(option = "plasma", begin = 0.2, end = 0.8) +
  my_theme +
  annotate("text",
           x = max(iris_df$sepal_length) - 0.5, y = min(iris_df$sepal_width) + 0.2,
           label = "Ellipses show 95% confidence regions",
           hjust = 1, size = 3, color = "darkgray", fontface = "italic")
""", r_globals, width: 720, height: 480)
IrisGuide.plot(~S"""
ggplot(iris_df, aes(x = petal_length, y = petal_width, color = species)) +
  geom_point(size = 3, alpha = 0.7) +
  stat_ellipse(aes(fill = species), geom = "polygon", alpha = 0.1, level = 0.95) +
  labs(title = "Petal Width vs. Petal Length",
       x = "Petal Length (cm)", y = "Petal Width (cm)",
       color = "Species", fill = "Species") +
  scale_color_viridis_d(option = "plasma", begin = 0.2, end = 0.8) +
  scale_fill_viridis_d(option = "plasma", begin = 0.2, end = 0.8) +
  my_theme +
  annotate("text",
           x = max(iris_df$petal_length) - 0.5, y = min(iris_df$petal_width) + 0.2,
           label = "Ellipses show 95% confidence regions",
           hjust = 1, size = 3, color = "darkgray", fontface = "italic")
""", r_globals, width: 720, height: 480)

Petal dimensions provide nearly perfect linear separation, especially for setosa. The sepal scatter shows more overlap between versicolor and virginica.

3.4 Correlation Analysis

IrisGuide.plot(~S"""
cor_matrix <- cor(iris_df %>% select(-species))
corrplot(cor_matrix,
         method      = "circle",
         type        = "upper",
         tl.col      = "black",
         tl.srt      = 45,
         addCoef.col = "black",
         col         = colorRampPalette(c("#4477AA", "white", "#EE6677"))(200),
         diag        = FALSE,
         title       = "Correlation Matrix of Iris Features",
         mar         = c(0, 0, 1, 0))
""", r_globals, width: 580, height: 520)

Petal length and petal width are highly correlated (r ≈ 0.96). Both petal dimensions also correlate strongly with sepal length, which is useful for classification but can introduce multicollinearity in linear models.
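
The coefficients quoted above can be checked numerically. The following is a minimal base R sketch, not part of the original notebook; it reads the built-in `datasets::iris` with its original column names (rather than the renamed `iris_df`) so that it runs in any R session, or as the source string passed to `IrisGuide.r`.

```r
# Pearson correlations between the four numeric Iris features.
cor_matrix <- cor(datasets::iris[, 1:4])

# Correlation between petal length and petal width (about 0.96).
petal_r <- cor_matrix["Petal.Length", "Petal.Width"]

# Correlations of sepal length with both petal dimensions.
sepal_vs_petal <- cor_matrix["Sepal.Length", c("Petal.Length", "Petal.Width")]

print(round(cor_matrix, 2))
```

Sepal width is the one feature whose correlations with the others are not strongly positive, which is visible in the printed matrix.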

3.5 Pair Plot

If the GGally package is installed, this section renders an enhanced pair plot with per-class correlation coefficients and density diagonals. Otherwise, it falls back to the base R pairs() function.

IrisGuide.plot(~S"""
ggally_available <- suppressWarnings(require(GGally, quietly = TRUE))
if (ggally_available) {
  p <- ggpairs(iris_df,
               columns = 1:4,
               aes(color = species, alpha = 0.7),
               upper = list(continuous = "cor"),
               lower = list(continuous = "points"),
               diag  = list(continuous = "densityDiag"),
               title = "Pair Plot for Iris Dataset") +
    scale_color_viridis_d(option = "plasma", begin = 0.2, end = 0.8) +
    scale_fill_viridis_d(option = "plasma", begin = 0.2, end = 0.8) +
    theme_minimal() +
    theme(axis.text = element_text(size = 7))
  print(p)
} else {
  pairs(iris_df[1:4],
        col  = as.numeric(iris_df$species),
        pch  = 19,
        main = "Scatterplot Matrix for Iris Dataset")
  par(xpd = TRUE)
  legend("bottomright",
         legend = levels(iris_df$species),
         col    = 1:3,
         pch    = 19,
         bty    = "n")
}
""", r_globals, width: 820, height: 720)

4. Data Preprocessing

4.1 Outlier Detection

Z-scores are computed within each species group. Observations with |z| > 2.5 on any feature are flagged as potential outliers.

{_result, r_globals} = IrisGuide.r(~S"""
iris_z <- iris_df %>%
  group_by(species) %>%
  mutate(across(where(is.numeric), ~scale(.)[, 1], .names = "{col}_z")) %>%
  ungroup()

outliers <- iris_z %>%
  select(contains("_z")) %>%
  mutate(across(everything(), ~abs(.) > 2.5)) %>%
  mutate(outlier_count = rowSums(.)) %>%
  bind_cols(iris_df) %>%
  dplyr::filter(outlier_count > 0)

cat("Potential outliers detected:", nrow(outliers), "observations\n")
if (nrow(outliers) > 0) {
  print(outliers %>%
    select(species, sepal_length, sepal_width, petal_length, petal_width, outlier_count))
}
""", r_globals)
:ok
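
The grouped z-score idea can also be written without dplyr. This is a small base R sketch for a single feature, assuming the built-in `datasets::iris` column names; the 2.5 threshold is the one used in this section.

```r
# Within-species z-scores for one feature, using base R only.
# ave() applies the function inside each species group and returns
# a vector aligned with the original rows.
z <- ave(datasets::iris$Sepal.Length, datasets::iris$Species,
         FUN = function(x) (x - mean(x)) / sd(x))

# Flag rows whose absolute z-score exceeds the 2.5 threshold.
flagged <- which(abs(z) > 2.5)

print(datasets::iris[flagged, c("Species", "Sepal.Length")])
```

Standardising within each species, rather than over the whole dataset, keeps a typical virginica from being flagged just because it is larger than a typical setosa.
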
IrisGuide.plot(~S"""
ggplot(iris_z, aes(x = species, y = sepal_length_z, color = species)) +
  geom_boxplot(outlier.shape = 8, outlier.size = 3) +
  geom_jitter(width = 0.2, alpha = 0.5) +
  geom_hline(yintercept = c(-2.5, 2.5), linetype = "dashed", color = "red", alpha = 0.7) +
  labs(title    = "Z-Scores of Sepal Length by Species",
       subtitle = "Dashed lines indicate outlier threshold (|z| = 2.5)",
       x = "Species", y = "Z-Score") +
  scale_color_viridis_d(option = "plasma", begin = 0.2, end = 0.8) +
  my_theme +
  theme(legend.position = "none")
""", r_globals, width: 640, height: 440)

4.2 Feature Engineering

New features are derived from the original measurements to capture geometric relationships between sepal and petal dimensions.

{_result, r_globals} = IrisGuide.r(~S"""
iris_engineered <- iris_df %>%
  mutate(
    sepal_area                  = sepal_length * sepal_width,
    petal_area                  = petal_length * petal_width,
    sepal_petal_ratio           = (sepal_length * sepal_width) / (petal_length * petal_width),
    sepal_length_to_width_ratio = sepal_length / sepal_width,
    petal_length_to_width_ratio = petal_length / petal_width
  )

cat("Engineered dataset — first 6 rows:\n")
print(head(iris_engineered))
""", r_globals)
:ok
IrisGuide.plot(~S"""
ggplot(iris_engineered, aes(x = species, y = petal_area, fill = species)) +
  geom_violin(alpha = 0.7) +
  geom_boxplot(width = 0.1, alpha = 0.7) +
  labs(title    = "Petal Area by Species",
       subtitle = "Engineered feature: petal_length × petal_width",
       x = "Species", y = "Petal Area (cm²)") +
  scale_fill_viridis_d(option = "plasma", begin = 0.2, end = 0.8) +
  my_theme +
  theme(legend.position = "none")
""", r_globals, width: 640, height: 440)

petal_area provides excellent separation between all three species and is a strong candidate for the classifiers.
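
One way to put numbers on that separation is to compare the per-species range of the engineered feature. The sketch below is illustrative only and uses the built-in `datasets::iris` column names instead of `iris_engineered`.

```r
# Petal area per flower, computed from the built-in data frame.
petal_area <- with(datasets::iris, Petal.Length * Petal.Width)

# Minimum and maximum petal area within each species (2 x 3 matrix).
area_range <- sapply(split(petal_area, datasets::iris$Species), range)
rownames(area_range) <- c("min", "max")
print(area_range)

# Setosa's largest petal area is below versicolor's smallest,
# so a single threshold on this feature isolates setosa.
setosa_max     <- area_range["max", "setosa"]
versicolor_min <- area_range["min", "versicolor"]
```

The printed ranges also show how much versicolor and virginica still overlap on this feature.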

4.3 Data Splitting with Stratification

An 80/20 stratified split ensures each species is proportionally represented in both the training and test sets.

{_result, r_globals} = IrisGuide.r(~S"""
train_index    <- createDataPartition(iris_df$species, p = 0.8, list = FALSE)
train_data     <- iris_df[train_index, ]
test_data      <- iris_df[-train_index, ]
train_data_eng <- iris_engineered[train_index, ]
test_data_eng  <- iris_engineered[-train_index, ]

split_comparison <- data.frame(
  Species          = names(table(train_data$species)),
  Training_Count   = as.vector(table(train_data$species)),
  Testing_Count    = as.vector(table(test_data$species)),
  Training_Percent = as.vector(table(train_data$species)) / nrow(train_data) * 100,
  Testing_Percent  = as.vector(table(test_data$species)) / nrow(test_data) * 100
)

cat("Train/test split (80/20 stratified):\n")
print(split_comparison, digits = 3)
""", r_globals)
:ok
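
To show what stratification means in practice, here is a hand-rolled base R sketch that samples 80% of the rows within each species. It is not caret's `createDataPartition` implementation, and the variable names are illustrative.

```r
set.seed(8888)  # same seed value as the notebook; any fixed seed works

# Row indices grouped by species.
idx_by_species <- split(seq_len(nrow(datasets::iris)), datasets::iris$Species)

# Draw 80% of the rows from each species separately.
train_idx <- unlist(lapply(idx_by_species, function(idx) {
  sample(idx, size = round(0.8 * length(idx)))
}))

train <- datasets::iris[train_idx, ]
test  <- datasets::iris[-train_idx, ]

print(table(train$Species))
print(table(test$Species))
```

With 50 flowers per species, every class contributes 40 rows to training and 10 to testing, whichever rows the sampler happens to pick.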

5. Model Building and Evaluation

For each model we: (1) train on the training data, (2) predict on the test data, and (3) evaluate with a confusion matrix.
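
As a reminder of what the confusion matrix summarises, the sketch below builds one by hand from six hypothetical predictions (invented for illustration, not model output) and derives accuracy and per-class recall from it.

```r
# Hypothetical predictions and true labels for six flowers.
lvls   <- c("setosa", "versicolor", "virginica")
actual <- factor(c("setosa", "setosa", "versicolor",
                   "versicolor", "virginica", "virginica"), levels = lvls)
pred   <- factor(c("setosa", "setosa", "versicolor",
                   "virginica", "virginica", "virginica"), levels = lvls)

# Rows are predictions and columns are reference labels,
# the same layout that caret::confusionMatrix prints.
cm <- table(Prediction = pred, Reference = actual)
print(cm)

# Accuracy is the share of observations on the diagonal.
accuracy <- sum(diag(cm)) / sum(cm)

# Per-class recall (caret calls this sensitivity): diagonal over column totals.
recall <- diag(cm) / colSums(cm)
```

Here one versicolor is mistaken for virginica, so accuracy is 5/6 and versicolor recall is 0.5 while the other two classes have recall 1.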

5.1 Decision Tree

{_result, r_globals} = IrisGuide.r(~S"""
dt_model <- rpart(species ~ ., data = train_data,
                  method  = "class",
                  control = rpart.control(cp = 0.01, minsplit = 5))
""", r_globals)

IrisGuide.plot(~S"""
rpart.plot(dt_model,
           extra       = 104,
           box.palette = "RdBu",
           shadow.col  = "gray80",
           nn          = TRUE,
           roundint    = FALSE,
           main        = "Decision Tree for Iris Classification")
""", r_globals, width: 840, height: 600)
{_result, r_globals} = IrisGuide.r(~S"""
dt_pred <- predict(dt_model, test_data, type = "class")
dt_prob <- predict(dt_model, test_data, type = "prob")
dt_cm   <- confusionMatrix(dt_pred, test_data$species)
print(dt_cm)
""", r_globals)
:ok

5.2 Random Forest

{_result, r_globals} = IrisGuide.r(~S"""
rf_model <- randomForest(species ~ .,
                         data       = train_data,
                         ntree      = 500,
                         mtry       = floor(sqrt(ncol(train_data) - 1)),
                         importance = TRUE,
                         proximity  = TRUE)
print(rf_model)
""", r_globals)
:ok
IrisGuide.plot(~S"""
plot(rf_model, main = "Random Forest: Error Rate vs Number of Trees")
# err.rate has four columns (OOB plus one per class), so the legend needs four styles
legend("topright", colnames(rf_model$err.rate), lty = 1:4, col = 1:4)
""", r_globals, width: 700, height: 480)
{_result, r_globals} = IrisGuide.r(~S"""
rf_pred <- predict(rf_model, test_data)
rf_prob <- predict(rf_model, test_data, type = "prob")
rf_cm   <- confusionMatrix(rf_pred, test_data$species)
print(rf_cm)
""", r_globals)
:ok

5.3 Support Vector Machine with Parameter Tuning

The SVM uses a radial basis function (RBF) kernel. A grid search over gamma and cost selects the best hyperparameters by cross-validation (10-fold by default in e1071's tune functions).

{_result, r_globals} = IrisGuide.r(~S"""
svm_tune <- tune.svm(species ~ .,
                     data   = train_data,
                     kernel = "radial",
                     gamma  = 10^(-3:-1),
                     cost   = 10^(0:2))

cat("Best SVM parameters:\n")
print(svm_tune$best.parameters)

svm_model <- svm(species ~ .,
                 data        = train_data,
                 kernel      = "radial",
                 gamma       = svm_tune$best.parameters$gamma,
                 cost        = svm_tune$best.parameters$cost,
                 probability = TRUE)

svm_pred <- predict(svm_model, test_data)
svm_prob <- attr(predict(svm_model, test_data, probability = TRUE), "probabilities")
svm_cm   <- confusionMatrix(svm_pred, test_data$species)
print(svm_cm)
""", r_globals)
:ok
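
The size of the search is easy to see by expanding the grid explicitly. This base R sketch only enumerates the candidate pairs passed to `tune.svm` above; it does not fit any model.

```r
# The candidate values used in the tuning cell.
grid <- expand.grid(
  gamma = 10^(-3:-1),  # 0.001, 0.01, 0.1
  cost  = 10^(0:2)     # 1, 10, 100
)

# Each row is one (gamma, cost) pair that gets cross-validated.
print(grid)
n_candidates <- nrow(grid)
```

Three values per parameter give nine combinations, each of which is fitted once per cross-validation fold.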

5.4 K-Nearest Neighbors with Optimal K

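We search k = 1 … 15 and select the value that maximises test accuracy. Because k is chosen on the test set, the accuracy reported for KNN is an optimistic estimate; tuning k by cross-validation on the training data would avoid this.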

{_result, r_globals} = IrisGuide.r(~S"""
find_optimal_k <- function(train, test, max_k = 15) {
  k_values <- 1:max_k
  accuracy <- numeric(length(k_values))

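  # Standardise with training-set statistics only, then apply them to the test set
  train_scaled <- scale(train %>% select(-species))
  test_scaled  <- scale(test %>% select(-species),
                        center = attr(train_scaled, "scaled:center"),
                        scale  = attr(train_scaled, "scaled:scale"))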

  for (i in seq_along(k_values)) {
    knn_pred    <- knn(train = train_scaled, test = test_scaled,
                       cl = train$species, k = k_values[i])
    accuracy[i] <- confusionMatrix(knn_pred, test$species)$overall["Accuracy"]
  }

  plot_data <- data.frame(K = k_values, Accuracy = accuracy)
  p <- ggplot(plot_data, aes(x = K, y = Accuracy)) +
    geom_line(color = "#4E79A7", linewidth = 1) +
    geom_point(color = "#4E79A7", size = 3) +
    geom_point(data = plot_data[which.max(accuracy), ],
               aes(x = K, y = Accuracy), color = "red", size = 4) +
    annotate("text",
             x     = plot_data$K[which.max(accuracy)],
             y     = max(accuracy) + 0.02,
             label = paste("Optimal k =", plot_data$K[which.max(accuracy)]),
             color = "red") +
    labs(title = "KNN Accuracy for Different Values of k",
         x = "Number of Neighbors (k)", y = "Accuracy") +
    my_theme +
    ylim(min(accuracy) - 0.05, max(accuracy) + 0.05)

  list(optimal_k = k_values[which.max(accuracy)], max_accuracy = max(accuracy), plot = p)
}

knn_optimization <- find_optimal_k(train_data, test_data)
cat("Optimal k:", knn_optimization$optimal_k, "\n")
cat("Accuracy at optimal k:", round(knn_optimization$max_accuracy, 4), "\n")
""", r_globals)
:ok
IrisGuide.plot(~S"""
print(knn_optimization$plot)
""", r_globals, width: 700, height: 460)
{_result, r_globals} = IrisGuide.r(~S"""
knn_optimal_k <- knn_optimization$optimal_k
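# Standardise with training-set statistics only, then apply them to the test set
train_scaled  <- scale(train_data %>% select(-species))
test_scaled   <- scale(test_data %>% select(-species),
                       center = attr(train_scaled, "scaled:center"),
                       scale  = attr(train_scaled, "scaled:scale"))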

knn_pred    <- knn(train = train_scaled, test = test_scaled,
                   cl = train_data$species, k = knn_optimal_k, prob = TRUE)
knn_prob_raw <- attr(knn_pred, "prob")

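# Build a probability matrix for the ROC section. knn() reports only the vote
# share of the winning class, so the other two classes stay at 0 (an approximation).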
knn_prob_matrix <- matrix(0, nrow = length(knn_pred), ncol = 3)
colnames(knn_prob_matrix) <- levels(train_data$species)
for (i in seq_along(knn_pred)) {
  idx <- which(levels(train_data$species) == knn_pred[i])
  knn_prob_matrix[i, idx] <- knn_prob_raw[i]
}

knn_cm <- confusionMatrix(knn_pred, test_data$species)
print(knn_cm)
""", r_globals)
:ok
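
An alternative that keeps the test set out of the tuning loop is leave-one-out cross-validation with `class::knn.cv`. The sketch below is one possible approach, not the notebook's method; for self-containment it uses all 150 rows of `datasets::iris` as a stand-in for the training set.

```r
library(class)  # recommended package shipped with standard R installations

set.seed(8888)  # knn.cv breaks ties at random

# Stand-in for the training set: all 150 rows, standardised.
x  <- scale(datasets::iris[, 1:4])
cl <- datasets::iris$Species

# Leave-one-out accuracy for each candidate k.
ks      <- 1:15
loo_acc <- sapply(ks, function(k) mean(knn.cv(x, cl, k = k) == cl))

best_k <- ks[which.max(loo_acc)]
print(data.frame(k = ks, loo_accuracy = round(loo_acc, 3)))
```

In the notebook, `x` and `cl` would be replaced by the scaled training features and `train_data$species`, so that the held-out test rows are used only for the final evaluation.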

5.5 XGBoost (Optional)

This section runs only if the R xgboost package is installed.

{_result, r_globals} = IrisGuide.r(~S"""
xgboost_available <- suppressWarnings(require(xgboost, quietly = TRUE))

if (xgboost_available) {
  xgb_train        <- model.matrix(~ . - 1, data = train_data %>% select(-species))
  xgb_test         <- model.matrix(~ . - 1, data = test_data %>% select(-species))
  xgb_train_labels <- as.numeric(train_data$species) - 1
  xgb_test_labels  <- as.numeric(test_data$species) - 1

  dtrain <- xgb.DMatrix(data = xgb_train, label = xgb_train_labels)
  dtest  <- xgb.DMatrix(data = xgb_test,  label = xgb_test_labels)

  params <- list(
    objective        = "multi:softprob",
    num_class        = 3,
    eta              = 0.1,
    max_depth        = 3,
    subsample        = 0.8,
    colsample_bytree = 0.8,
    min_child_weight = 1,
    gamma            = 0
  )

  xgb_model <- xgb.train(
    params                = params,
    data                  = dtrain,
    nrounds               = 100,
    watchlist             = list(train = dtrain, test = dtest),
    early_stopping_rounds = 10,
    verbose               = 0
  )

  xgb_prob <- predict(xgb_model, dtest, reshape = TRUE)
  xgb_pred <- levels(test_data$species)[max.col(xgb_prob)]
  xgb_pred <- factor(xgb_pred, levels = levels(test_data$species))

  xgb_cm <- confusionMatrix(xgb_pred, test_data$species)
  print(xgb_cm)
} else {
  cat("xgboost is not installed — skipping.\n")
  cat("Install with: install.packages('xgboost')\n")
}
""", r_globals)
:ok
IrisGuide.plot(~S"""
xgboost_available <- suppressWarnings(require(xgboost, quietly = TRUE))
if (xgboost_available && exists("xgb_model")) {
  xgb_train <- model.matrix(~ . - 1, data = train_data %>% select(-species))
  importance_xgb <- xgb.importance(feature_names = colnames(xgb_train), model = xgb_model)
  xgb.plot.importance(importance_xgb, top_n = 10, main = "XGBoost Feature Importance")
}
""", r_globals, width: 680, height: 440)

6. Model Comparison

6.1 Accuracy Comparison Table and Plot

{_result, r_globals} = IrisGuide.r(~S"""
models     <- c("Decision Tree", "Random Forest", "SVM", "KNN")
accuracies <- c(
  dt_cm$overall["Accuracy"],
  rf_cm$overall["Accuracy"],
  svm_cm$overall["Accuracy"],
  knn_cm$overall["Accuracy"]
)

if (exists("xgb_cm")) {
  models     <- c(models, "XGBoost")
  accuracies <- c(accuracies, xgb_cm$overall["Accuracy"])
}

model_comparison <- data.frame(Model = models, Accuracy = accuracies)
model_comparison <- model_comparison[order(model_comparison$Accuracy, decreasing = TRUE), ]

cat("Model Accuracy Comparison:\n")
print(model_comparison, row.names = FALSE, digits = 4)
cat("\nBest model:", model_comparison$Model[1],
    "with accuracy:", round(model_comparison$Accuracy[1], 4), "\n")
""", r_globals)
:ok
IrisGuide.plot(~S"""
ggplot(model_comparison, aes(x = reorder(Model, Accuracy), y = Accuracy, fill = Model)) +
  geom_bar(stat = "identity") +
  geom_text(aes(label = sprintf("%.3f", Accuracy)), vjust = -0.5, size = 4) +
  labs(title    = "Model Accuracy Comparison",
       subtitle = "Performance on Test Set",
       x = "Model", y = "Accuracy") +
  scale_fill_viridis_d(option = "plasma", begin = 0.2, end = 0.8) +
  my_theme +
  theme(axis.text.x = element_text(angle = 45, hjust = 1),
        legend.position = "none") +
  ylim(0, 1.05)
""", r_globals, width: 640, height: 460)

6.2 ROC Curves

One-vs-rest ROC curves are plotted for each model and species class. The legend shows each model’s mean AUC across all three classes.
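
The AUC behind each one-vs-rest curve can be read as the probability that a randomly chosen positive scores higher than a randomly chosen negative. The base R sketch below computes it from ranks; the helper `rank_auc` and the choice of negated petal length as a score are illustrative assumptions, not part of pROC or of the original notebook.

```r
# AUC from ranks (equivalent to the Mann-Whitney U statistic).
rank_auc <- function(score, is_positive) {
  r     <- rank(score)            # average ranks handle ties
  n_pos <- sum(is_positive)
  n_neg <- sum(!is_positive)
  (sum(r[is_positive]) - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
}

# One-vs-rest labels for setosa.
is_setosa <- datasets::iris$Species == "setosa"

# Illustrative score: setosa has the shortest petals, so negate the length
# to make higher scores mean "more setosa-like".
score <- -datasets::iris$Petal.Length

auc_setosa <- rank_auc(score, is_setosa)
print(auc_setosa)
```

Because every setosa petal is shorter than every non-setosa petal, this single-feature score already reaches an AUC of 1 for the setosa-vs-rest task.
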

IrisGuide.plot(~S"""
calculate_multiclass_roc <- function(probs, actual_class) {
  actual_class <- factor(actual_class)
  roc_list     <- list()
  auc_values   <- numeric(length(levels(actual_class)))
  for (i in seq_along(levels(actual_class))) {
    binary_labels <- ifelse(actual_class == levels(actual_class)[i], 1, 0)
    roc_obj       <- roc(binary_labels, probs[, i], quiet = TRUE)
    roc_list[[i]] <- roc_obj
    auc_values[i] <- auc(roc_obj)
  }
  names(roc_list)   <- levels(actual_class)
  names(auc_values) <- levels(actual_class)
  list(roc = roc_list, auc = auc_values, mean_auc = mean(auc_values))
}

model_colors <- c(
  "Decision Tree" = "#4E79A7",
  "Random Forest" = "#F28E2B",
  "SVM"           = "#E15759",
  "KNN"           = "#76B7B2",
  "XGBoost"       = "#59A14F"
)

par(mar = c(5, 5, 4, 2) + 0.1)
plot(NULL, NULL, xlim = c(1, 0), ylim = c(0, 1),
     xlab = "Specificity", ylab = "Sensitivity",
     main = "ROC Curves for Different Models", las = 1)
# Chance line: with specificity on a reversed x-axis it is sensitivity = 1 - specificity
abline(a = 1, b = -1, lty = 2, col = "gray")

mean_auc_values <- list()
probs_list <- list(
  "Decision Tree" = dt_prob,
  "Random Forest" = rf_prob,
  "SVM"           = svm_prob,
  "KNN"           = knn_prob_matrix
)
if (exists("xgb_prob")) probs_list[["XGBoost"]] <- xgb_prob

for (model_name in names(probs_list)) {
  roc_result <- calculate_multiclass_roc(probs_list[[model_name]], test_data$species)
  for (i in seq_along(roc_result$roc)) {
    lines(roc_result$roc[[i]],
          col = adjustcolor(model_colors[model_name], alpha.f = 0.75),
          lwd = 2)
  }
  mean_auc_values[[model_name]] <- roc_result$mean_auc
}

legend_text <- paste0(names(mean_auc_values), " (mean AUC: ",
                      round(unlist(mean_auc_values), 3), ")")
legend("bottomright", legend = legend_text,
       col = model_colors[names(mean_auc_values)],
       lwd = 2, cex = 0.82, bg = "white")
""", r_globals, width: 740, height: 560)

7. Feature Importance

7.1 Random Forest Feature Importance

{_result, r_globals} = IrisGuide.r(~S"""
rf_importance <- importance(rf_model)
importance_df <- data.frame(
  Feature           = rownames(rf_importance),
  Accuracy_Decrease = rf_importance[, "MeanDecreaseAccuracy"],
  Gini_Decrease     = rf_importance[, "MeanDecreaseGini"]
) %>%
  arrange(desc(Gini_Decrease))

cat("Random Forest feature importance (sorted by Gini decrease):\n")
print(importance_df, row.names = FALSE, digits = 3)
""", r_globals)
:ok
IrisGuide.plot(~S"""
p_gini <- ggplot(importance_df,
                 aes(x = reorder(Feature, Gini_Decrease), y = Gini_Decrease, fill = Feature)) +
  geom_bar(stat = "identity") +
  geom_text(aes(label = round(Gini_Decrease, 2)), hjust = -0.2, size = 3.5) +
  labs(title    = "RF Importance (Gini)",
       subtitle = "Mean Decrease in Gini Index",
       x = NULL, y = "Mean Decrease Gini") +
  scale_fill_viridis_d(option = "plasma", begin = 0.2, end = 0.8) +
  my_theme +
  theme(legend.position = "none") +
  coord_flip() +
  ylim(0, max(importance_df$Gini_Decrease) * 1.2)

p_acc <- ggplot(importance_df,
                aes(x = reorder(Feature, Accuracy_Decrease), y = Accuracy_Decrease, fill = Feature)) +
  geom_bar(stat = "identity") +
  geom_text(aes(label = round(Accuracy_Decrease, 2)), hjust = -0.2, size = 3.5) +
  labs(title    = "RF Importance (Accuracy)",
       subtitle = "Mean Decrease in Accuracy",
       x = NULL, y = "Mean Decrease Accuracy") +
  scale_fill_viridis_d(option = "plasma", begin = 0.2, end = 0.8) +
  my_theme +
  theme(legend.position = "none") +
  coord_flip() +
  ylim(0, max(importance_df$Accuracy_Decrease) * 1.2)

grid.arrange(p_gini, p_acc, ncol = 2,
             top = grid::textGrob("Random Forest Feature Importance",
                                  gp = grid::gpar(fontsize = 15, fontface = "bold")))
""", r_globals, width: 920, height: 420)

7.2 Decision Tree Feature Importance

IrisGuide.plot(~S"""
dt_importance <- data.frame(
  Feature    = names(dt_model$variable.importance),
  Importance = dt_model$variable.importance
) %>%
  arrange(desc(Importance))

ggplot(dt_importance, aes(x = reorder(Feature, Importance), y = Importance, fill = Feature)) +
  geom_bar(stat = "identity") +
  geom_text(aes(label = round(Importance, 2)), hjust = -0.2, size = 3.5) +
  labs(title = "Feature Importance from Decision Tree",
       x = NULL, y = "Importance") +
  scale_fill_viridis_d(option = "plasma", begin = 0.2, end = 0.8) +
  my_theme +
  theme(legend.position = "none") +
  coord_flip() +
  ylim(0, max(dt_importance$Importance) * 1.2)
""", r_globals, width: 660, height: 380)

Both models agree: petal length and petal width are the most discriminative features, with sepal dimensions contributing less to classification accuracy.
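
To give a feel for what a Gini-based importance accumulates, the sketch below computes the impurity decrease of a single split on petal length. It is an illustration under stated assumptions (the `gini` helper and the 2.45 cm threshold are chosen for the example), not randomForest's or rpart's exact computation.

```r
# Gini impurity of a vector of class labels: 1 - sum(p_k^2).
gini <- function(labels) {
  p <- table(labels) / length(labels)
  1 - sum(p^2)
}

species <- datasets::iris$Species

# Illustrative threshold: any value between 1.9 and 3.0 cm separates setosa.
left <- datasets::iris$Petal.Length < 2.45

# Impurity before the split and the size-weighted impurity after it.
parent_gini <- gini(species)
child_gini  <- mean(left) * gini(species[left]) +
               mean(!left) * gini(species[!left])

gini_decrease <- parent_gini - child_gini
print(gini_decrease)
```

The split sends all 50 setosa flowers to a pure left node and leaves a balanced two-class right node, so the impurity drops from 2/3 to 1/3. Tree-based importances add up decreases like this one across every split that uses a feature.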

8. Advanced: Feature Space Visualization

The top three features from the Random Forest importance ranking are visualised in an interactive 3D scatter plot. This section requires the R plotly package and the plotly_ex Elixir dependency.

Requires: section 7 must have been run first so that importance_df exists in R.

{r_plotly_available, _} =
  Rx.eval(
    "base::requireNamespace('plotly', quietly = TRUE)",
    %{}
  )

if Rx.decode(r_plotly_available) do
  {r_plot, _r_globals} =
    Rx.eval(
      ~S"""
      suppressMessages(library(plotly))
      top_features <- importance_df$Feature[1:3]

      plot_ly(
        data   = iris_df,
        x      = ~get(top_features[1]),
        y      = ~get(top_features[2]),
        z      = ~get(top_features[3]),
        color  = ~species,
        colors = c("#4E79A7", "#F28E2B", "#E15759"),
        type   = "scatter3d",
        mode   = "markers",
        marker = list(size = 5, opacity = 0.8)
      ) %>%
        layout(
          title = "3D Visualization of Top 3 Features",
          scene = list(
            xaxis = list(title = gsub("_", " ", top_features[1])),
            yaxis = list(title = gsub("_", " ", top_features[2])),
            zaxis = list(title = gsub("_", " ", top_features[3]))
          )
        )
      """,
      r_globals
    )

  {:ok, fig} = Rx.Plotly.from_r(r_plot)
  Plotly.show(fig)
else
  Kino.Text.new(
    "R `plotly` package not installed.\nInstall with: install.packages('plotly')"
  )
end

9. Conclusion

This notebook demonstrated a complete machine learning workflow on the Iris dataset using R from Elixir via Rx:

1. Data Exploration and Visualization

  • The dataset contains 150 balanced samples (50 per species) with four numeric features.
  • Petal dimensions provide near-perfect class separation, especially for setosa.
  • Sepal dimensions show more overlap between versicolor and virginica.

2. Feature Engineering

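  • Derived features such as petal_area and sepal_petal_ratio add geometric context. The engineered splits (train_data_eng, test_data_eng) are prepared in section 4.3, but the models in section 5 are trained on the four original measurements.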
  • Feature importance analysis from Random Forest and Decision Tree confirmed that petal dimensions are the most discriminative.

3. Model Building and Evaluation

  • Four classifiers were implemented: Decision Tree, Random Forest, SVM (with RBF kernel + grid-search tuning), and KNN (with optimal k search).
  • All models achieve high accuracy (≥ 0.93) on this dataset — a testament to how well-separated the classes are in feature space.
  • XGBoost is included as an optional fifth model when the R package is available.

4. Key Takeaways

  • The Iris dataset shows how effective simple machine learning models can be when features are discriminative.
  • Petal measurements are sufficient to classify setosa perfectly; the challenge lies in distinguishing versicolor from virginica.
  • Rx lets you leverage the full R ecosystem from Elixir: tidy data manipulation, ggplot2 visualizations, caret workflows, and any CRAN package — all coordinated from a Livebook notebook.