fit() with TensorFlowWhen you’re doing supervised learning, you can use fit()
and everything works smoothly.
When you need to take control of every little detail, you can write your own training loop entirely from scratch.
But what if you need a custom training algorithm, but you still want
to benefit from the convenient features of fit(), such as
callbacks, built-in distribution support, or step fusing?
A core principle of Keras is progressive disclosure of complexity. You should always be able to get into lower-level workflows in a gradual way. You shouldn’t fall off a cliff if the high-level functionality doesn’t exactly match your use case. You should be able to gain more control over the small details while retaining a commensurate amount of high-level convenience.
When you need to customize what fit() does, you should
override the training step function of the Model
class. This is the function that is called by
fit() for every batch of data. You will then be able to
call fit() as usual – and it will be running your own
learning algorithm.
Note that this pattern does not prevent you from building models with
the Functional API. You can do this whether you’re building
Sequential models, Functional API models, or subclassed
models.
Let’s see how that works.
Let’s start from a simple example:
Model.train_step(self, data).The input argument data is what gets passed to fit as
training data:
fit(x, y, ...), then
data will be the list (x, y)tf_dataset, by calling
fit(dataset, ...), then data will be what gets
yielded by dataset at each batch.In the body of the train_step() method, we implement a
regular training update, similar to what you are already familiar with.
Importantly, we compute the loss via
self.compute_loss(), which wraps the loss(es)
function(s) that were passed to compile().
Similarly, we call metric$update_state(y, y_pred) on
metrics from self$metrics, to update the state of the
metrics that were passed in compile(), and we query results
from self$metrics at the end to retrieve their current
value.
CustomModel <- new_model_class(
  "CustomModel",
  train_step = function(data) {
    c(x, y = NULL, sample_weight = NULL) %<-% data
    with(tf$GradientTape() %as% tape, {
      y_pred <- self(x, training = TRUE)
      loss <- self$compute_loss(y = y, y_pred = y_pred,
                                sample_weight = sample_weight)
    })
    # Compute gradients
    trainable_vars <- self$trainable_variables
    gradients <- tape$gradient(loss, trainable_vars)
    # Update weights
    self$optimizer$apply(gradients, trainable_vars)
    # Update metrics (includes the metric that tracks the loss)
    for (metric in self$metrics) {
      if (metric$name == "loss")
        metric$update_state(loss)
      else
        metric$update_state(y, y_pred)
    }
    # Return a dict mapping metric names to current value
    metrics <- lapply(self$metrics, function(m) m$result())
    metrics <- setNames(metrics, sapply(self$metrics, function(m) m$name))
    metrics
  }
)Let’s try this out:
# Construct and compile an instance of CustomModel
inputs <- keras_input(shape = 32)
outputs <- layer_dense(inputs, 1)
model <- CustomModel(inputs, outputs)
model |> compile(optimizer = "adam", loss = "mse", metrics = "mae")
# Just use `fit` as usual
x <- random_normal(c(1000, 32))
y <- random_normal(c(1000, 1))
model |> fit(x, y, epochs = 3)## Epoch 1/3
## 32/32 - 1s - 23ms/step - mae: 1.4339 - loss: 3.2271
## Epoch 2/3
## 32/32 - 0s - 2ms/step - mae: 1.3605 - loss: 2.9034
## Epoch 3/3
## 32/32 - 0s - 2ms/step - mae: 1.2960 - loss: 2.6272Naturally, you could just skip passing a loss function in
compile(), and instead do everything manually in
train_step. Likewise for metrics.
Here’s a lower-level example, that only uses compile()
to configure the optimizer:
Metric instances to track our loss
and a MAE score (in __init__()).train_step() that updates the
state of these metrics (by calling update_state() on them),
then query them (via result()) to return their current
average value, to be displayed by the progress bar and to be pass to any
callback.reset_states() on our
metrics between each epoch! Otherwise calling result()
would return an average since the start of training, whereas we usually
work with per-epoch averages. Thankfully, the framework can do that for
us: just list any metric you want to reset in the metrics
property of the model. The model will call reset_states()
on any object listed here at the beginning of each fit()
epoch or at the beginning of a call to evaluate().CustomModel <- new_model_class(
  "CustomModel",
  initialize = function(...) {
    super$initialize(...)
    self$loss_tracker <- metric_mean(name = "loss")
    self$mae_metric <- metric_mean_absolute_error(name = "mae")
    self$loss_fn <- loss_mean_squared_error()
  },
  train_step = function(data) {
    c(x, y = NULL, sample_weight = NULL) %<-% data
    with(tf$GradientTape() %as% tape, {
      y_pred <- self(x, training = TRUE)
      loss <- self$loss_fn(y, y_pred, sample_weight = sample_weight)
    })
    # Compute gradients
    trainable_vars <- self$trainable_variables
    gradients <- tape$gradient(loss, trainable_vars)
    # Update weights
    self$optimizer$apply(gradients, trainable_vars)
    # Compute our own metrics
    self$loss_tracker$update_state(loss)
    self$mae_metric$update_state(y, y_pred)
    # Return a dict mapping metric names to current value
    list(
      loss = self$loss_tracker$result(),
      mae = self$mae_metric$result()
    )
  },
  metrics = mark_active(function() {
    # We list our `Metric` objects here so that `reset_states()` can be
    # called automatically at the start of each epoch
    # or at the start of `evaluate()`.
    list(self$loss_tracker, self$mae_metric)
  })
)
# Construct and compile an instance of CustomModel
inputs <- keras_input(shape = 32)
outputs <- layer_dense(inputs, 1)
model <- CustomModel(inputs, outputs)
# We don't pass a loss or metrics here.
model |> compile(optimizer = "adam")
# Just use `fit` as usual
x <- random_normal(c(1000, 32))
y <- random_normal(c(1000, 1))
model |> fit(x, y, epochs = 3)## Epoch 1/3
## 32/32 - 1s - 20ms/step - loss: 2.5170 - mae: 1.2923
## Epoch 2/3
## 32/32 - 0s - 2ms/step - loss: 2.2689 - mae: 1.2241
## Epoch 3/3
## 32/32 - 0s - 2ms/step - loss: 2.0578 - mae: 1.1633sample_weight &
class_weightYou may have noticed that our first basic example didn’t make any
mention of sample weighting. If you want to support the
fit() arguments sample_weight and
class_weight, you’d simply do the following:
sample_weight from the data
argumentcompute_loss & update_state
(of course, you could also just apply it manually if you don’t rely on
compile() for losses & metrics)CustomModel <- new_model_class(
  "CustomModel",
  train_step = function(data) {
    c(x, y = NULL, sample_weight = NULL) %<-% data
    with(tf$GradientTape() %as% tape, {
      y_pred <- self(x, training = TRUE)
      loss <- self$compute_loss(y = y, y_pred = y_pred,
                                sample_weight = sample_weight)
    })
    # Compute gradients
    trainable_vars <- self$trainable_variables
    gradients <- tape$gradient(loss, trainable_vars)
    # Update weights
    self$optimizer$apply_gradients(zip_lists(gradients, trainable_vars))
    # Update metrics (includes the metric that tracks the loss)
    for (metric in self$metrics) {
      if (metric$name == "loss") {
        metric$update_state(loss)
      } else {
        metric$update_state(y, y_pred, sample_weight = sample_weight)
      }
    }
    # Return a dict mapping metric names to current value
    metrics <- lapply(self$metrics, function(m) m$result())
    metrics <- setNames(metrics, sapply(self$metrics, function(m) m$name))
    metrics
  }
)
# Construct and compile an instance of CustomModel
inputs <- keras_input(shape = 32)
outputs <- layer_dense(inputs, units = 1)
model <- CustomModel(inputs, outputs)
model |> compile(optimizer = "adam", loss = "mse", metrics = "mae")
# You can now use sample_weight argument
x <- random_normal(c(1000, 32))
y <- random_normal(c(1000, 1))
sw <- random_normal(c(1000, 1))
model |> fit(x, y, sample_weight = sw, epochs = 3)## Epoch 1/3
## 32/32 - 1s - 23ms/step - mae: 1.3434 - loss: 0.1681
## Epoch 2/3
## 32/32 - 0s - 2ms/step - mae: 1.3364 - loss: 0.1394
## Epoch 3/3
## 32/32 - 0s - 4ms/step - mae: 1.3286 - loss: 0.1148What if you want to do the same for calls to
model.evaluate()? Then you would override
test_step in exactly the same way. Here’s what it looks
like:
CustomModel <- new_model_class(
  "CustomModel",
  test_step = function(data) {
    # Unpack the data
    c(x, y = NULL, sw = NULL) %<-% data
    # Compute predictions
    y_pred = self(x, training = FALSE)
    # Updates the metrics tracking the loss
    self$compute_loss(y = y, y_pred = y_pred, sample_weight = sw)
    # Update the metrics.
    for (metric in self$metrics) {
      if (metric$name != "loss") {
        metric$update_state(y, y_pred, sample_weight = sw)
      }
    }
    # Return a dict mapping metric names to current value.
    # Note that it will include the loss (tracked in self.metrics).
    metrics <- lapply(self$metrics, function(m) m$result())
    metrics <- setNames(metrics, sapply(self$metrics, function(m) m$name))
    metrics
  }
)
# Construct an instance of CustomModel
inputs <- keras_input(shape = 32)
outputs <- layer_dense(inputs, 1)
model <- CustomModel(inputs, outputs)
model |> compile(loss = "mse", metrics = "mae")
# Evaluate with our custom test_step
x <- random_normal(c(1000, 32))
y <- random_normal(c(1000, 1))
model |> evaluate(x, y)## 32/32 - 0s - 10ms/step - mae: 1.3871 - loss: 0.0000e+00## $loss
## tf.Tensor(0.0, shape=(), dtype=float32)
##
## $compile_metrics
## $compile_metrics$mae
## tf.Tensor(1.3871489, shape=(), dtype=float32)Let’s walk through an end-to-end example that leverages everything you just learned.
Let’s consider:
# Create the discriminator
discriminator <-
  keras_model_sequential(name = "discriminator", input_shape = c(28, 28, 1)) |>
  layer_conv_2d(filters = 64, kernel_size = c(3, 3),
                strides = c(2, 2),  padding = "same") |>
  layer_activation_leaky_relu(negative_slope = 0.2) |>
  layer_conv_2d(filters = 128, kernel_size = c(3, 3),
                strides = c(2, 2), padding = "same") |>
  layer_activation_leaky_relu(negative_slope = 0.2) |>
  layer_global_max_pooling_2d() |>
  layer_dense(units = 1)
# Create the generator
latent_dim <- 128
generator <-
  keras_model_sequential(name = "generator", input_shape = latent_dim) |>
  layer_dense(7 * 7 * 128) |>
  layer_activation_leaky_relu(negative_slope = 0.2) |>
  layer_reshape(target_shape = c(7, 7, 128)) |>
  layer_conv_2d_transpose(filters = 128, kernel_size = c(4, 4),
                          strides = c(2, 2), padding = "same") |>
  layer_activation_leaky_relu(negative_slope = 0.2) |>
  layer_conv_2d_transpose(filters = 128, kernel_size = c(4, 4),
                          strides = c(2, 2), padding = "same") |>
  layer_activation_leaky_relu(negative_slope = 0.2) |>
  layer_conv_2d(filters = 1, kernel_size = c(7, 7), padding = "same",
                activation = "sigmoid")Here’s a feature-complete GAN class, overriding
compile() to use its own signature, and implementing the
entire GAN algorithm in 17 lines in train_step:
GAN <- Model(
  classname = "GAN",
  initialize = function(discriminator, generator, latent_dim, ...) {
    super$initialize(...)
    self$discriminator <- discriminator
    self$generator <- generator
    self$latent_dim <- as.integer(latent_dim)
    self$d_loss_tracker <- metric_mean(name = "d_loss")
    self$g_loss_tracker <- metric_mean(name = "g_loss")
  },
  compile = function(d_optimizer, g_optimizer, loss_fn, ...) {
    super$compile(...)
    self$d_optimizer <- d_optimizer
    self$g_optimizer <- g_optimizer
    self$loss_fn <- loss_fn
  },
  metrics = active_property(function() {
    list(self$d_loss_tracker, self$g_loss_tracker)
  }),
  train_step = function(real_images) {
    # Sample random points in the latent space
    batch_size <- shape(real_images)[[1]]
    random_latent_vectors <-
      tf$random$normal(shape(batch_size, self$latent_dim))
    # Decode them to fake images
    generated_images <- self$generator(random_latent_vectors)
    # Combine them with real images
    combined_images <- op_concatenate(list(generated_images,
                                           real_images))
    # Assemble labels discriminating real from fake images
    labels <- op_concatenate(list(op_ones(c(batch_size, 1)),
                                  op_zeros(c(batch_size, 1))))
    # Add random noise to the labels - important trick!
    labels %<>% `+`(tf$random$uniform(shape(.), maxval = 0.05))
    # Train the discriminator
    with(tf$GradientTape() %as% tape, {
      predictions <- self$discriminator(combined_images)
      d_loss <- self$loss_fn(labels, predictions)
    })
    grads <- tape$gradient(d_loss, self$discriminator$trainable_weights)
    self$d_optimizer$apply_gradients(
      zip_lists(grads, self$discriminator$trainable_weights))
    # Sample random points in the latent space
    random_latent_vectors <-
      tf$random$normal(shape(batch_size, self$latent_dim))
    # Assemble labels that say "all real images"
    misleading_labels <- op_zeros(c(batch_size, 1))
    # Train the generator (note that we should *not* update the weights
    # of the discriminator)!
    with(tf$GradientTape() %as% tape, {
      predictions <- self$discriminator(self$generator(random_latent_vectors))
      g_loss <- self$loss_fn(misleading_labels, predictions)
    })
    grads <- tape$gradient(g_loss, self$generator$trainable_weights)
    self$g_optimizer$apply_gradients(
      zip_lists(grads, self$generator$trainable_weights))
    list(d_loss = d_loss, g_loss = g_loss)
  }
)Let’s test-drive it:
batch_size <- 64
c(c(x_train, .), c(x_test, .)) %<-% dataset_mnist()
all_digits <- op_concatenate(list(x_train, x_test))
all_digits <- op_reshape(all_digits, c(-1, 28, 28, 1))
dataset <- all_digits |>
  tfdatasets::tensor_slices_dataset() |>
  tfdatasets::dataset_map(\(x) op_cast(x, "float32") / 255) |>
  tfdatasets::dataset_shuffle(buffer_size = 1024) |>
  tfdatasets::dataset_batch(batch_size = batch_size)
gan <- GAN(discriminator = discriminator,
           generator = generator,
           latent_dim = latent_dim)
gan |> compile(
  d_optimizer = optimizer_adam(learning_rate = 0.0003),
  g_optimizer = optimizer_adam(learning_rate = 0.0003),
  loss_fn = loss_binary_crossentropy(from_logits = TRUE)
)
# To limit the execution time, we only train on 100 batches. You can train on
# the entire dataset. You will need about 20 epochs to get nice results.
gan |> fit(
  tfdatasets::dataset_take(dataset, 100),
  epochs = 1
)## 100/100 - 5s - 53ms/step - d_loss: 0.0000e+00 - g_loss: 0.0000e+00The ideas behind deep learning are simple, so why should their implementation be painful?