This guide will cover everything you need to know to build your own subclassed layers and models. In particular, you’ll learn about the following features:
- The Layer class
- The add_weight() method
- The build() method
- The add_loss() method
- The training argument in call()
- The mask argument in call()
Let’s dive in.
The Layer class: the combination of state (weights) and some computation
One of the central abstractions in Keras is the Layer
class. A layer encapsulates both a state (the layer’s “weights”) and a
transformation from inputs to outputs (a “call”, the layer’s forward
pass).
Here’s a densely-connected layer. It has two state variables: the
variables w and b.
layer_linear <- Layer("Linear",
  initialize = function(units = 32, input_dim = 32, ...) {
    super$initialize(...)
    self$w <- self$add_weight(
      shape = shape(input_dim, units),
      initializer = "random_normal",
      trainable = TRUE
    )
    self$b <- self$add_weight(
      shape = shape(units),
      initializer = "zeros",
      trainable = TRUE
    )
  },
  call = function(inputs) {
    op_matmul(inputs, self$w) + self$b
  }
)
You would use a layer by calling it on some tensor input(s), much like an R function.
x <- op_ones(c(2, 2))
linear_layer <- layer_linear(units = 4, input_dim = 2)
y <- linear_layer(x)
print(y)
## tf.Tensor(
## [[0.02153057 0.15450525 0.0205495  0.04493225]
##  [0.02153057 0.15450525 0.0205495  0.04493225]], shape=(2, 4), dtype=float32)
Note that the weights w and b are
automatically tracked by the layer upon being set as layer
attributes:
linear_layer$weights
## [[1]]
## <Variable path=linear/variable, shape=(2, 4), dtype=float32, value=[[-0.06251299  0.05335509  0.01485647 -0.00985784]
##  [ 0.08404355  0.10115016  0.00569303  0.05479009]]>
##
## [[2]]
## <Variable path=linear/variable_1, shape=(4), dtype=float32, value=[0. 0. 0. 0.]>
Besides trainable weights, you can add non-trainable weights to a layer as well. Such weights are meant not to be taken into account during backpropagation, when you are training the layer.
Here’s how to add and use a non-trainable weight:
layer_compute_sum <- Layer(
  "ComputeSum",
  initialize = function(input_dim) {
    super$initialize()
    self$total <- self$add_weight(
      initializer = "zeros",
      shape = shape(input_dim),
      trainable = FALSE
    )
  },
  call = function(inputs) {
    self$total$assign_add(op_sum(inputs, axis = 1))
    self$total
  }
)
x <- op_ones(c(2, 2))
my_sum <- layer_compute_sum(input_dim = 2)
y <- my_sum(x)
print(as.array(y))
## [1] 2 2
y <- my_sum(x)
print(as.array(y))
## [1] 4 4
It’s part of layer$weights, but it gets categorized as a
non-trainable weight:
cat("weights:", length(my_sum$weights))
## weights: 1
cat("non-trainable weights:", length(my_sum$non_trainable_weights))
## non-trainable weights: 1
# It's not included in the trainable weights:
cat("trainable_weights:", length(my_sum$trainable_weights))
## trainable_weights: 0
Our Linear layer above took an input_dim
argument that was used to compute the shape of the weights
w and b in initialize():
layer_linear <- Layer("Linear",
  initialize = function(units = 32, input_dim = 32, ...) {
    super$initialize(...)
    self$w <- self$add_weight(
      shape = shape(input_dim, units),
      initializer = "random_normal",
      trainable = TRUE
    )
    self$b <- self$add_weight(
      shape = shape(units),
      initializer = "zeros",
      trainable = TRUE
    )
  },
  call = function(inputs) {
    op_matmul(inputs, self$w) + self$b
  }
)
In many cases, you may not know in advance the size of your inputs, and you would like to lazily create weights when that value becomes known, some time after instantiating the layer.
In the Keras API, we recommend creating layer weights in the
build(input_shape) method of your layer. Like
this:
layer_linear <- Layer(
  "Linear",
  initialize = function(units = 32, ...) {
    self$units <- as.integer(units)
    super$initialize(...)
  },
  build = function(input_shape) {
    self$w <- self$add_weight(
      shape = shape(tail(input_shape, 1), self$units),
      initializer = "random_normal",
      trainable = TRUE
    )
    self$b <- self$add_weight(
      shape = shape(self$units),
      initializer = "zeros",
      trainable = TRUE
    )
  },
  call = function(inputs) {
    op_matmul(inputs, self$w) + self$b
  }
)
Calling your layer automatically runs
build() the first time it is called. You now have a layer that’s lazy and
thus easier to use:
# At instantiation, we don't know on what inputs this is going to get called
linear_layer <- layer_linear(units = 32)
# The layer's weights are created dynamically the first time the layer is called
y <- linear_layer(x)
Implementing build() separately as shown above nicely
separates creating weights only once from using weights in every
call.
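As a quick check of the lazy behaviour described above, the following sketch counts the layer's weights before and after the first call. It repeats the layer_linear definition so that it runs on its own; the names lazy_layer, n_before and n_after are illustrative only.

```r
library(keras3)

# Same lazily built layer as above, repeated so the snippet is self-contained.
layer_linear <- Layer(
  "Linear",
  initialize = function(units = 32, ...) {
    self$units <- as.integer(units)
    super$initialize(...)
  },
  build = function(input_shape) {
    self$w <- self$add_weight(
      shape = shape(tail(input_shape, 1), self$units),
      initializer = "random_normal",
      trainable = TRUE
    )
    self$b <- self$add_weight(
      shape = shape(self$units),
      initializer = "zeros",
      trainable = TRUE
    )
  },
  call = function(inputs) {
    op_matmul(inputs, self$w) + self$b
  }
)

lazy_layer <- layer_linear(units = 4)
n_before <- length(lazy_layer$weights)  # build() has not run yet

y <- lazy_layer(op_ones(c(2, 3)))       # first call triggers build()
n_after <- length(lazy_layer$weights)   # w and b now exist

cat("weights before first call:", n_before, "\n")
cat("weights after first call:", n_after, "\n")
cat("output shape:", dim(as.array(y)), "\n")
```

The input's last dimension (3 here) is only known at call time, which is why the shape of w cannot be fixed in initialize().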
If you assign a Layer instance as an attribute of another Layer, the outer layer will start tracking the weights created by the inner layer.
We recommend creating such sublayers in the initialize()
method and leaving it to the first call() to trigger building
their weights.
MLPBlock <- Layer(
  "MLPBlock",
  initialize = function() {
    super$initialize()
    self$linear_1 <- layer_linear(units = 32)
    self$linear_2 <- layer_linear(units = 32)
    self$linear_3 <- layer_linear(units = 1)
  },
  call = function(inputs) {
    inputs |>
      self$linear_1() |>
      activation_relu() |>
      self$linear_2() |>
      activation_relu() |>
      self$linear_3()
  }
)
mlp <- MLPBlock()
# The first call to the `mlp` will create the weights
y <- mlp(op_ones(shape = c(3, 64)))
cat("weights:", length(mlp$weights), "\n")
## weights: 6
cat("trainable weights:", length(mlp$trainable_weights), "\n")
## trainable weights: 6
As long as a layer only uses APIs from the ops namespace
(i.e. functions starting with op_) or other Keras
namespaces such as activation_*, random_*, or
layer_*, it can be used with any backend:
TensorFlow, JAX, or PyTorch.
All layers you’ve seen so far in this guide work with all Keras backends.
The ops namespace gives you access to:
- The NumPy API, e.g. op_matmul, op_sum, op_reshape, op_stack, etc.
- Neural network-specific APIs such as op_softmax, op_conv, op_binary_crossentropy, op_relu, etc.
You can also use backend-native APIs in your layers (such as
tf$nn functions), but if you do this, then your layer will
only be usable with the backend in question. For instance, you could
write the following JAX-specific layer using jax$numpy:
# keras3::install_keras(backend = c("jax"))
jax <- reticulate::import("jax")
Linear <- new_layer_class(
  ...
  call = function(inputs) {
    jax$numpy$matmul(inputs, self$w) + self$b
  }
)
This would be the equivalent TensorFlow-specific layer:
library(tensorflow)
Linear <- new_layer_class(
  ...
  call = function(inputs) {
    tf$matmul(inputs, self$w) + self$b
  }
)
And this would be the equivalent PyTorch-specific layer:
torch <- reticulate::import("torch")
Linear <- new_layer_class(
  ...
  call = function(inputs) {
    torch$matmul(inputs, self$w) + self$b
  }
)
Because cross-backend compatibility is a tremendously useful property, we strongly recommend that you seek to always make your layers backend-agnostic by leveraging only Keras APIs.
The add_loss() method
When writing the call() method of a layer, you can
create loss tensors that you will want to use later, when writing your
training loop. This is doable by calling
self$add_loss(value):
# A layer that creates an activity regularization loss
layer_activity_regularization <- Layer(
  "ActivityRegularizationLayer",
  initialize = function(rate = 1e-2) {
    self$rate <- as.numeric(rate)
    super$initialize()
  },
  call = function(inputs) {
    self$add_loss(self$rate * op_mean(inputs))
    inputs
  }
)
These losses (including those created by any inner layer) can be
retrieved via layer$losses. This property is reset at the
start of every call to the top-level layer, so that
layer$losses always contains the loss values created during
the last forward pass.
layer_outer <- Layer(
  "OuterLayer",
  initialize = function() {
    super$initialize()
    self$activity_reg <- layer_activity_regularization(rate = 1e-2)
  },
  call = function(inputs) {
    self$activity_reg(inputs)
    inputs
  }
)
layer <- layer_outer()
# No losses yet since the layer has never been called
cat("losses:", length(layer$losses), "\n")
## losses: 0
x <- layer(op_zeros(c(1, 1)))
# We created one loss value
cat("losses:", length(layer$losses), "\n")
## losses: 1
# `layer$losses` gets reset at the start of each call
x <- layer(op_zeros(c(1, 1)))
# This is the loss created during the call above
cat("losses:", length(layer$losses), "\n")
## losses: 1
In addition, the losses property also contains
regularization losses created for the weights of any inner layer:
layer_outer_with_kernel_regularizer <- Layer(
  "OuterLayerWithKernelRegularizer",
  initialize = function() {
    super$initialize()
    self$dense <- layer_dense(units = 32,
                              kernel_regularizer = regularizer_l2(1e-3))
  },
  call = function(inputs) {
    self$dense(inputs)
  }
)
layer <- layer_outer_with_kernel_regularizer()
x <- layer(op_zeros(c(1, 1)))
# This is `1e-3 * sum(layer$dense$kernel ** 2)`,
# created by the `kernel_regularizer` above.
print(layer$losses)
## [[1]]
## tf.Tensor(0.002025157, shape=(), dtype=float32)
These losses are meant to be taken into account when writing custom training loops.
They also work seamlessly with fit() (they get
automatically summed and added to the main loss, if any):
inputs <- keras_input(shape = 3)
outputs <- inputs |> layer_activity_regularization()
model <- keras_model(inputs, outputs)
# If there is a loss passed in `compile`, the regularization
# losses get added to it
model |> compile(optimizer = "adam", loss = "mse")
model |> fit(random_normal(c(2, 3)), random_normal(c(2, 3)), epochs = 1)
## 1/1 - 0s - 161ms/step - loss: 1.9081
# It's also possible not to pass any loss in `compile`,
# since the model already has a loss to minimize, via the `add_loss`
# call during the forward pass!
model |> compile(optimizer = "adam")
model |> fit(random_normal(c(2, 3)), random_normal(c(2, 3)), epochs = 1)
## 1/1 - 0s - 139ms/step - loss: 1.6613
If you need your custom layers to be serializable as part of a Functional model, you can optionally
implement a get_config() method:
layer_linear <- Layer(
  "Linear",
  initialize = function(units = 32) {
    self$units <- as.integer(units)
    super$initialize()
  },
  build = function(input_shape) {
    self$w <- self$add_weight(
      shape = shape(tail(input_shape, 1), self$units),
      initializer = "random_normal",
      trainable = TRUE
    )
    self$b <- self$add_weight(
      shape = shape(self$units),
      initializer = "zeros",
      trainable = TRUE
    )
  },
  call = function(inputs) {
    op_matmul(inputs, self$w) + self$b
  },
  get_config = function() {
    list(units = self$units)
  }
)
# Now you can recreate the layer from its config:
layer <- layer_linear(units = 64)
config <- get_config(layer)
str(config)
## List of 1
##  $ units: int 64
##  - attr(*, "__class__")=<class '<r-globalenv>.Linear'>
Note that the initialize() method of the base
Layer class takes some keyword arguments, in particular a
name and a dtype. It’s good practice to pass
these arguments to the parent class in initialize() and to
include them in the layer config:
Linear <- new_layer_class(
  "Linear",
  initialize = function(units = 32, ...) {
    self$units <- as.integer(units)
    super$initialize(...)
  },
  build = function(input_shape) {
    self$w <- self$add_weight(
      shape = shape(tail(input_shape, 1), self$units),
      initializer = "random_normal",
      trainable = TRUE
    )
    self$b <- self$add_weight(
      shape = shape(self$units),
      initializer = "zeros",
      trainable = TRUE
    )
  },
  call = function(inputs) {
    op_matmul(inputs, self$w) + self$b
  },
  get_config = function() {
    list(units = self$units)
  }
)
layer <- Linear(units = 64)
config <- get_config(layer)
str(config)
## List of 1
##  $ units: int 64
##  - attr(*, "__class__")=<class '<r-globalenv>.Linear'>
If you need more flexibility when deserializing the layer from its
config, you can also override the from_config() class
method. This is the base implementation of
from_config():
Layer(
  ...,
  from_config = function(config) {
    # calling `__class__`() creates a new instance and calls initialize()
    do.call(`__class__`, config)
  }
)
To learn more about serialization and saving, see the complete guide to saving and serializing models.
The training argument in the call() method
Some layers, in particular the BatchNormalization layer
and the Dropout layer, have different behaviors during
training and inference. For such layers, it is standard practice to
expose a training (boolean) argument in the
call() method.
By exposing this argument in call(), you enable the
built-in training and evaluation loops (e.g. fit()) to
correctly use the layer in training and inference.
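To see the effect of the training argument in isolation, here is a minimal sketch that calls the built-in layer_dropout() directly in both modes. The variable names (dropout, y_infer, y_train) are illustrative, and the rate of 0.5 is an arbitrary choice.

```r
library(keras3)

x <- op_ones(c(2, 4))
dropout <- layer_dropout(rate = 0.5)

# Inference mode: dropout is a no-op, so the input passes through unchanged.
y_infer <- dropout(x, training = FALSE)

# Training mode: each entry is either zeroed or scaled by 1 / (1 - rate).
y_train <- dropout(x, training = TRUE)

print(as.array(y_infer))
print(as.array(y_train))
```

A custom layer can follow the same pattern by branching on training inside its own call():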
layer_custom_dropout <- Layer(
  "CustomDropout",
  initialize = function(rate, ...) {
    super$initialize(...)
    self$rate <- rate
    self$seed_generator <- random_seed_generator(1337)
  },
  call = function(inputs, training = NULL) {
    if (isTRUE(training))
      return(random_dropout(inputs, rate = self$rate,
                            seed = self$seed_generator))
    inputs
  }
)
The mask argument in the call() method
The other privileged argument supported by call() is the
mask argument.
You will find it in all Keras RNN layers. A mask is a boolean tensor (one boolean value per timestep in the input) used to skip certain input timesteps when processing timeseries data.
Keras will automatically pass the correct mask argument
to call() for layers that support it, when a mask is
generated by a prior layer. Mask-generating layers are the
Embedding layer configured with
mask_zero = TRUE, and the Masking layer.
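As a rough illustration of what such a mask looks like, the sketch below builds an Embedding layer with mask_zero = TRUE and asks it for the mask it would generate for a zero-padded batch. It assumes the layer's compute_mask() method is reachable from R via `$`; the token values and sizes are made up for the example.

```r
library(keras3)

# Two sequences of 4 timesteps each, padded with 0 (hypothetical token ids).
token_ids <- op_convert_to_tensor(
  rbind(c(3, 5, 0, 0),
        c(7, 0, 0, 0)),
  dtype = "int32"
)

embedding <- layer_embedding(input_dim = 10, output_dim = 4, mask_zero = TRUE)

# One boolean per timestep: TRUE where the token id is non-zero.
mask <- embedding$compute_mask(token_ids)
print(as.array(mask))
```

Downstream layers that accept a mask argument in call() would receive this tensor and skip the timesteps marked FALSE.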
The Model class
In general, you will use the Layer class to define inner
computation blocks, and will use the Model class to define
the outer model – the object you will train.
For instance, in a ResNet50 model, you would have several ResNet
blocks subclassing Layer, and a single Model
encompassing the entire ResNet50 network.
The Model class has the same API as Layer,
with the following differences:
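- It exposes built-in training, evaluation, and prediction loops (fit(), evaluate(), predict()).
- It exposes the list of its inner layers, via the model$layers property.
- It exposes saving and serialization APIs (save(), save_weights()…).
Effectively, the Layer class corresponds to what we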
refer to in the literature as a “layer” (as in “convolution layer” or
“recurrent layer”) or as a “block” (as in “ResNet block” or “Inception
block”).
Meanwhile, the Model class corresponds to what is
referred to in the literature as a “model” (as in “deep learning model”)
or as a “network” (as in “deep neural network”).
So if you’re wondering, “should I use the Layer class or
the Model class?”, ask yourself: will I need to call
fit() on it? Will I need to call save() on it?
If so, go with Model. If not (either because your class is
just a block in a bigger system, or because you are writing training
& saving code yourself), use Layer.
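The differences listed above can be seen on a very small subclassed model. This is only a sketch: TinyModel, its two dense layers, and the input sizes are hypothetical and chosen purely to keep the example short.

```r
library(keras3)

# A tiny subclassed Model with two inner layers (illustrative only).
TinyModel <- Model(
  "TinyModel",
  initialize = function(...) {
    super$initialize(...)
    self$hidden <- layer_dense(units = 8, activation = "relu")
    self$out <- layer_dense(units = 1)
  },
  call = function(inputs) {
    inputs |>
      self$hidden() |>
      self$out()
  }
)

model <- TinyModel()
x <- matrix(1, nrow = 4, ncol = 3)

# Unlike a plain Layer, a Model has a built-in prediction loop...
preds <- predict(model, x, verbose = 0)

# ...and exposes its inner layers via model$layers.
cat("inner layers:", length(model$layers), "\n")
cat("prediction shape:", dim(preds), "\n")
```

The same class written with Layer() instead of Model() could still be called on tensors, but it would not offer predict(), fit(), or the saving utilities.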
For instance, we could take our mini-resnet example above, and use it
to build a Model that we could train with
fit(), and that we could save with
save_model():
ResNet <- Model(
  "ResNet",
  initialize = function(num_classes = 1000, ...) {
    super$initialize(...)
    self$block_1 <- layer_resnet_block()
    self$block_2 <- layer_resnet_block()
    self$global_pool <- layer_global_average_pooling_2d()
    self$classifier <- layer_dense(num_classes)
  },
  call = function(inputs) {
    inputs |>
      self$block_1() |>
      self$block_2() |>
      self$global_pool() |>
      self$classifier()
  }
)
resnet <- ResNet()
dataset <- ...
resnet |> fit(dataset, epochs = 10)
resnet |> save_model("filepath.keras")
Here’s what you’ve learned so far:
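- A Layer encapsulates a state (created in initialize() or build()) and some computation (defined in call()).
- Layers can be recursively nested to create new, bigger computation blocks.
- Layers are backend-agnostic as long as they only use Keras APIs. You can use backend-native APIs (such as jax$numpy, torch$nn or tf$nn), but then your layer will only be usable with that specific backend.
- Layers can create and track losses (typically regularization losses) via add_loss().
- The outer container, the thing you want to train, is a Model. A Model is just like a Layer, but with added training and serialization utilities.
Let’s put all of these things together into an end-to-end example: we’re going to implement a Variational AutoEncoder (VAE) in a backend-agnostic fashion, so that it runs the same with TensorFlow, JAX, and PyTorch. We’ll train it on MNIST digits.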
Our VAE will be a subclass of Model, built as a nested
composition of layers that subclass Layer. It will feature
a regularization loss (KL divergence).
layer_sampling <- Layer(
  "Sampling",
  initialize = function(...) {
    super$initialize(...)
    self$seed_generator <- random_seed_generator(1337)
  },
  call = function(inputs) {
    c(z_mean, z_log_var) %<-% inputs
    batch <- op_shape(z_mean)[[1]]
    dim <- op_shape(z_mean)[[2]]
    epsilon <- random_normal(shape = c(batch, dim),
                             seed=self$seed_generator)
    z_mean + op_exp(0.5 * z_log_var) * epsilon
  }
)
# Maps MNIST digits to a triplet (z_mean, z_log_var, z).
layer_encoder <- Layer(
  "Encoder",
  initialize = function(latent_dim = 32, intermediate_dim = 64, ...) {
    super$initialize(...)
    self$dense_proj <-
      layer_dense(units = intermediate_dim,  activation = "relu")
    self$dense_mean <- layer_dense(units = latent_dim)
    self$dense_log_var <- layer_dense(units = latent_dim)
    self$sampling <- layer_sampling()
  },
  call = function(inputs) {
    x <- self$dense_proj(inputs)
    z_mean <- self$dense_mean(x)
    z_log_var <- self$dense_log_var(x)
    z <- self$sampling(list(z_mean, z_log_var))
    list(z_mean, z_log_var, z)
  }
)
# Converts z, the encoded digit vector, back into a readable digit.
layer_decoder <- Layer(
  "Decoder",
  initialize = function(original_dim, intermediate_dim = 64, ...) {
    super$initialize(...)
    self$dense_proj <-
      layer_dense(units = intermediate_dim, activation = "relu")
    self$dense_output <-
      layer_dense(units = original_dim, activation = "sigmoid")
  },
  call = function(inputs) {
    x <- self$dense_proj(inputs)
    self$dense_output(x)
  }
)
# Combines the encoder and decoder into an end-to-end model for training.
VariationalAutoEncoder <- Model(
  "VariationalAutoEncoder",
  initialize = function(original_dim, intermediate_dim = 64, latent_dim = 32,
                        name = "autoencoder", ...) {
    super$initialize(name = name, ...)
    self$original_dim <- original_dim
    self$encoder <- layer_encoder(latent_dim = latent_dim,
                            intermediate_dim = intermediate_dim)
    self$decoder <- layer_decoder(original_dim = original_dim,
                            intermediate_dim = intermediate_dim)
  },
  call = function(inputs) {
    c(z_mean, z_log_var, z) %<-% self$encoder(inputs)
    reconstructed <- self$decoder(z)
    # Add KL divergence regularization loss.
    kl_loss <- -0.5 * op_mean(z_log_var - op_square(z_mean) - op_exp(z_log_var) + 1)
    self$add_loss(kl_loss)
    reconstructed
  }
)
Let’s train it on MNIST using the fit() API:
c(c(x_train, .), .) %<-% dataset_mnist()
x_train <- x_train |>
  op_reshape(c(60000, 784)) |>
  op_cast("float32") |>
  op_divide(255)
original_dim <- 784
vae <- VariationalAutoEncoder(
  original_dim = 784,
  intermediate_dim = 64,
  latent_dim = 32
)
optimizer <- optimizer_adam(learning_rate = 1e-3)
vae |> compile(optimizer, loss = loss_mean_squared_error())
vae |> fit(x_train, x_train, epochs = 2, batch_size = 64)
## Epoch 1/2
## 938/938 - 4s - 5ms/step - loss: 0.0748
## Epoch 2/2
## 938/938 - 1s - 1ms/step - loss: 0.0676