Keras 3 is a deep learning framework that works with TensorFlow, JAX, and PyTorch interchangeably. This notebook will walk you through key Keras 3 workflows.
Let’s start by installing Keras 3:
We’re going to be using the tensorflow backend here – but you can
edit the string below to "jax" or "torch" and
hit “Restart runtime”, and the whole notebook will run just the same!
This entire guide is backend-agnostic.
Let’s start with the Hello World of ML: training a convnet to classify MNIST digits.
Here’s the data:
# Load the data and split it between train and test sets. dataset_mnist()
# yields nested (images, labels) pairs, unpacked here in one assignment.
c(c(x_train, y_train), c(x_test, y_test)) %<-% keras3::dataset_mnist()

# Scale images to the [0, 1] range
x_train <- x_train / 255
x_test <- x_test / 255

# Make sure images have shape (28, 28, 1); the -1 leaves the sample dimension
# to be inferred by array_reshape().
x_train <- array_reshape(x_train, c(-1, 28, 28, 1))
x_test <- array_reshape(x_test, c(-1, 28, 28, 1))

# Check the shape of both splits (the second call was missing although its
# output was shown).
dim(x_train)
## [1] 60000    28    28     1
dim(x_test)
## [1] 10000    28    28     1
Here’s our model.
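Different model-building options that Keras offers include:
- The Sequential API (what we use below)
- The Functional API
- Writing your own models via subclassing (shown later in this guide)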
# Model parameters
num_classes <- 10
input_shape <- c(28, 28, 1)

# Assemble the convnet in a single chain starting from the constructor:
# two 64-filter convolutions, max pooling, two 128-filter convolutions,
# global average pooling, dropout, and a softmax classifier.
model <- keras_model_sequential(input_shape = input_shape) |>
  layer_conv_2d(64, kernel_size = c(3, 3), activation = "relu") |>
  layer_conv_2d(64, kernel_size = c(3, 3), activation = "relu") |>
  layer_max_pooling_2d(pool_size = c(2, 2)) |>
  layer_conv_2d(128, kernel_size = c(3, 3), activation = "relu") |>
  layer_conv_2d(128, kernel_size = c(3, 3), activation = "relu") |>
  layer_global_average_pooling_2d() |>
  layer_dropout(rate = 0.5) |>
  layer_dense(units = num_classes, activation = "softmax")
Here’s our model summary:
## Model: "sequential"
## ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
## ┃ Layer (type)                    ┃ Output Shape           ┃       Param # ┃
## ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
## │ conv2d (Conv2D)                 │ (None, 26, 26, 64)     │           640 │
## ├─────────────────────────────────┼────────────────────────┼───────────────┤
## │ conv2d_1 (Conv2D)               │ (None, 24, 24, 64)     │        36,928 │
## ├─────────────────────────────────┼────────────────────────┼───────────────┤
## │ max_pooling2d (MaxPooling2D)    │ (None, 12, 12, 64)     │             0 │
## ├─────────────────────────────────┼────────────────────────┼───────────────┤
## │ conv2d_2 (Conv2D)               │ (None, 10, 10, 128)    │        73,856 │
## ├─────────────────────────────────┼────────────────────────┼───────────────┤
## │ conv2d_3 (Conv2D)               │ (None, 8, 8, 128)      │       147,584 │
## ├─────────────────────────────────┼────────────────────────┼───────────────┤
## │ global_average_pooling2d        │ (None, 128)            │             0 │
## │ (GlobalAveragePooling2D)        │                        │               │
## ├─────────────────────────────────┼────────────────────────┼───────────────┤
## │ dropout (Dropout)               │ (None, 128)            │             0 │
## ├─────────────────────────────────┼────────────────────────┼───────────────┤
## │ dense (Dense)                   │ (None, 10)             │         1,290 │
## └─────────────────────────────────┴────────────────────────┴───────────────┘
##  Total params: 260,298 (1016.79 KB)
##  Trainable params: 260,298 (1016.79 KB)
##  Non-trainable params: 0 (0.00 B)
We use the compile() method to specify the optimizer,
loss function, and the metrics to monitor. Note that with the JAX and
TensorFlow backends, XLA compilation is turned on by default.
# Configure training: Adam optimizer, sparse categorical cross-entropy loss,
# and sparse categorical accuracy reported under the name "acc".
compile(
  model,
  optimizer = "adam",
  loss = "sparse_categorical_crossentropy",
  metrics = list(metric_sparse_categorical_accuracy(name = "acc"))
)
Let’s train and evaluate the model. We’ll set aside a validation split of 15% of the data during training to monitor generalization on unseen data.
batch_size <- 128
epochs <- 10

# Checkpoint the model every epoch, and stop early when val_loss has not
# improved for 2 consecutive epochs.
callbacks <- list(
  callback_model_checkpoint(filepath = "model_at_epoch_{epoch}.keras"),
  callback_early_stopping(monitor = "val_loss", patience = 2)
)

# Train, holding out 15% of the training data for validation.
fit(
  model,
  x_train, y_train,
  batch_size = batch_size,
  epochs = epochs,
  validation_split = 0.15,
  callbacks = callbacks
)
## Epoch 1/10
## 399/399 - 8s - 19ms/step - acc: 0.7477 - loss: 0.7456 - val_acc: 0.9654 - val_loss: 0.1189
## Epoch 2/10
## 399/399 - 3s - 7ms/step - acc: 0.9377 - loss: 0.2072 - val_acc: 0.9764 - val_loss: 0.0786
## Epoch 3/10
## 399/399 - 3s - 7ms/step - acc: 0.9567 - loss: 0.1474 - val_acc: 0.9820 - val_loss: 0.0623
## Epoch 4/10
## 399/399 - 3s - 7ms/step - acc: 0.9648 - loss: 0.1182 - val_acc: 0.9860 - val_loss: 0.0490
## Epoch 5/10
## 399/399 - 3s - 7ms/step - acc: 0.9707 - loss: 0.1007 - val_acc: 0.9872 - val_loss: 0.0475
## Epoch 6/10
## 399/399 - 3s - 7ms/step - acc: 0.9753 - loss: 0.0868 - val_acc: 0.9886 - val_loss: 0.0403
## Epoch 7/10
## 399/399 - 3s - 7ms/step - acc: 0.9761 - loss: 0.0790 - val_acc: 0.9890 - val_loss: 0.0408
## Epoch 8/10
## 399/399 - 3s - 7ms/step - acc: 0.9794 - loss: 0.0685 - val_acc: 0.9878 - val_loss: 0.0443
During training, we were saving a model at the end of each epoch. You can also save the model in its latest state like this:
model |> save_model("final_model.keras", overwrite = TRUE)
And reload it like this:
model <- load_model("final_model.keras")
Next, you can query predictions of class probabilities with
predict():
predictions <- model |> predict(x_test)
## 313/313 - 1s - 2ms/step
dim(predictions)
## [1] 10000    10
That’s it for the basics!
Keras enables you to write custom Layers, Models, Metrics, Losses, and Optimizers that work across TensorFlow, JAX, and PyTorch with the same codebase. Let’s take a look at custom layers first.
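The op_ namespace contains:
- NumPy-style array operations, e.g. op_stack or op_matmul.
- Neural-network-specific operations that have no NumPy counterpart, e.g. op_conv or op_binary_crossentropy.
Let’s make a custom Dense layer that works with all
backends: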
# Custom fully connected layer: multiplies its input by a kernel, adds a bias,
# and optionally applies an activation function.
#
# Constructor arguments:
#   units      - size of the layer's output dimension.
#   activation - optional function applied to the result; NULL skips it.
#   name, ...  - forwarded to the base Layer initializer.
layer_my_dense <- Layer(
  classname = "MyDense",

  initialize = function(units, activation = NULL, name = NULL, ...) {
    super$initialize(name = name, ...)
    self$activation <- activation
    self$units <- units
  },

  # Create the kernel and bias, sized from the last entry of the input shape.
  build = function(input_shape) {
    n_features <- tail(input_shape, 1)
    self$w <- self$add_weight(
      name = "kernel",
      shape = shape(n_features, self$units),
      initializer = initializer_glorot_normal(),
      trainable = TRUE
    )
    self$b <- self$add_weight(
      name = "bias",
      shape = shape(self$units),
      initializer = initializer_zeros(),
      trainable = TRUE
    )
  },

  call = function(inputs) {
    # Use Keras ops to create backend-agnostic layers/metrics/etc.
    projected <- op_matmul(inputs, self$w) + self$b
    if (is.null(self$activation)) {
      return(projected)
    }
    self$activation(projected)
  }
)
Next, let’s make a custom Dropout layer that relies on
the random_* namespace:
# Custom dropout layer built on the random_* namespace.
#
# Constructor arguments:
#   rate - dropout rate handed to random_dropout().
#   name - optional layer name.
#   seed - optional seed handed to random_seed_generator().
#   ...  - further arguments forwarded to the base Layer initializer.
#
# Note: call() takes no `training` argument, so random_dropout() is applied on
# every call.
layer_my_dropout <- Layer(
  classname = "MyDropout",
  initialize = function(rate, name = NULL, seed = NULL, ...) {
    # Forward `...` to the base initializer, as layer_my_dense does. It was
    # previously accepted by the signature but silently discarded.
    super$initialize(name = name, ...)
    self$rate <- rate
    # Use seed_generator for managing RNG state.
    # It is a state element and its seed variable is
    # tracked as part of `layer$variables`.
    self$seed_generator <- random_seed_generator(seed)
  },
  call = function(inputs) {
    # Use `keras3::random_*` for random ops.
    random_dropout(inputs, self$rate, seed = self$seed_generator)
  }
)
Next, let’s write a custom subclassed model that uses our two custom layers:
# Subclassed model: a small convolutional feature extractor followed by the
# custom dropout layer and the custom dense classifier.
#
# Constructor arguments:
#   num_classes - number of units in the final dense layer.
#   ...         - forwarded to the base Model initializer.
MyModel <- Model(
  "MyModel",

  initialize = function(num_classes, ...) {
    super$initialize(...)

    # Convolutional stack ending in global average pooling.
    self$conv_base <- keras_model_sequential() |>
      layer_conv_2d(filters = 64, kernel_size = c(3, 3), activation = "relu") |>
      layer_conv_2d(filters = 64, kernel_size = c(3, 3), activation = "relu") |>
      layer_max_pooling_2d(pool_size = c(2, 2)) |>
      layer_conv_2d(filters = 128, kernel_size = c(3, 3), activation = "relu") |>
      layer_conv_2d(filters = 128, kernel_size = c(3, 3), activation = "relu") |>
      layer_global_average_pooling_2d()

    # Classification head made of the two custom layers.
    self$dp <- layer_my_dropout(rate = 0.5)
    self$dense <- layer_my_dense(
      units = num_classes,
      activation = activation_softmax
    )
  },

  # Forward pass: features -> dropout -> class probabilities.
  call = function(inputs) {
    features <- self$conv_base(inputs)
    dropped <- self$dp(features)
    self$dense(dropped)
  }
)
Let’s compile it and fit it:
# Instantiate the subclassed model for the 10 digit classes.
model <- MyModel(num_classes = 10)

# Same loss, optimizer, and metric as before, given as objects, not strings.
compile(
  model,
  loss = loss_sparse_categorical_crossentropy(),
  optimizer = optimizer_adam(learning_rate = 1e-3),
  metrics = list(metric_sparse_categorical_accuracy(name = "acc"))
)

fit(
  model,
  x_train, y_train,
  batch_size = batch_size,
  epochs = 1, # For speed
  validation_split = 0.15
)
## 399/399 - 7s - 17ms/step - acc: 0.7350 - loss: 0.7737 - val_acc: 0.9276 - val_loss: 0.2389
All Keras models can be trained and evaluated on a wide variety of data sources, independently of the backend you’re using. This includes:
- tf_dataset objects
- DataLoader objects
- PyDataset objects
They all work whether you’re using TensorFlow, JAX, or PyTorch as your Keras backend.
Let’s try this out with tf_dataset:
library(tfdatasets, exclude = "shape")

# Turn an (inputs, targets) pair into a batched, prefetching tf_dataset.
as_batched_dataset <- function(x, y, batch_size) {
  list(x, y) |>
    tensor_slices_dataset() |>
    dataset_batch(batch_size) |>
    dataset_prefetch(buffer_size = tf$data$AUTOTUNE)
}

train_dataset <- as_batched_dataset(x_train, y_train, batch_size)
test_dataset <- as_batched_dataset(x_test, y_test, batch_size)

# Fresh model, compiled with the same loss, optimizer, and metric as before.
model <- MyModel(num_classes = 10)
compile(
  model,
  loss = loss_sparse_categorical_crossentropy(),
  optimizer = optimizer_adam(learning_rate = 1e-3),
  metrics = list(metric_sparse_categorical_accuracy(name = "acc"))
)

# The datasets are already batched, so no batch_size is passed to fit().
fit(model, train_dataset, epochs = 1, validation_data = test_dataset)
## 469/469 - 8s - 17ms/step - acc: 0.7492 - loss: 0.7477 - val_acc: 0.9125 - val_loss: 0.2952
This concludes our short overview of the new multi-backend capabilities of Keras 3. Next, you can learn about:
Customizing what happens in fit()
Want to implement a non-standard training algorithm yourself but
still want to benefit from the power and usability of
fit()? It’s easy to customize fit() to support
arbitrary use cases: