Keras Core: Keras AND PyTorch?! A first look at the new Keras (3.0)

TensorFlow vs PyTorch was one of THE debates of the last decade or so among deep learning practitioners. Personally, I ended up in the PyTorch camp, but whenever I talked to TensorFlow people, their main argument was, “But Keras is so convenient!”

The PyTorch ecosystem has had PyTorch Lightning for a few years now. Although it simplifies a few things, it never reached the widespread use that Keras did for TensorFlow. Meanwhile, PyTorch itself has become more and more popular over the years (as seen above in Google Trends).

That brings us to today and Keras Core, which is compatible with PyTorch and was announced on July 10th, 2023!

In Fall 2023, Keras Core will officially become Keras 3.0, giving us the best of both worlds. Let’s take a look:


What is Keras Core?

Keras Core is a library by the Keras team that enables you to quickly build and train deep learning models.

Keras has always been about user experience (UX), quick prototyping, and user-friendliness. Starting with Theano in 2015, it adopted TensorFlow as an additional backend in 2016 and was officially integrated into TensorFlow as a high-level API a year later.

Keras Core is basically the same as Keras, with the main difference that it now supports TensorFlow AND PyTorch as backends. Oh, and JAX as well.

For now, it remains separate from the main Keras repository, but it will become Keras 3.0 in Fall 2023. Therefore, it is essentially a beta version right now.

How to install and configure Keras Core

Selecting PyTorch, JAX or TensorFlow as the Keras backend

To select which backend will be used, use the environment variable “KERAS_BACKEND”. You can export it via a terminal command like this:

$ export KERAS_BACKEND="torch"

Or, you can use a .env file and load it in your Python code with a library like python-dotenv. (I especially like this method for virtual environments and for working with Jupyter Notebooks in VS Code.)

The currently available options are "tensorflow", "jax" and "torch".

There are a few other configuration options; check the official documentation.

Note that the backend needs to be configured before importing keras_core. (If you forget, just restart the kernel or script.)
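If you prefer setting the variable from Python rather than the shell, it has to happen before that import. Here is a minimal sketch; the helper name select_keras_backend is my own hypothetical wrapper, not part of Keras Core, and it only uses os.environ plus a check of sys.modules:

```python
import os
import sys

VALID_BACKENDS = {"tensorflow", "jax", "torch"}


def select_keras_backend(name: str) -> None:
    """Hypothetical helper: set KERAS_BACKEND, refusing if keras_core is already loaded."""
    if name not in VALID_BACKENDS:
        raise ValueError(f"Unknown backend {name!r}; expected one of {sorted(VALID_BACKENDS)}")
    if "keras_core" in sys.modules:
        # The backend is picked at import time, so changing the variable afterwards does nothing
        raise RuntimeError("keras_core is already imported; restart the kernel or script first")
    os.environ["KERAS_BACKEND"] = name


select_keras_backend("torch")
# import keras_core as keras  # a fresh import would now use the PyTorch backend
```

Raising instead of silently setting the variable makes the "forgot to restart the kernel" case visible.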

Installing and importing Keras Core

Installing is pretty simple with the following command for pip:

pip install keras-core

You can then import it in your Python code like this:

import keras_core as keras

Getting started with Keras MNIST

The Keras team provided a short tutorial that trains a model on MNIST (basically the “Hello World” of deep learning); you can find it here in full.

But let’s look at the main parts together from a PyTorch perspective:

Importing and setup of Keras Core

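# Load KERAS_BACKEND (e.g. KERAS_BACKEND="torch") from a .env file;
# this has to run before keras_core is imported
from dotenv import load_dotenv
load_dotenv()

import keras_core as keras
import numpy as np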

The dataset (MNIST of course)

Next, we download the MNIST data using the convenient keras.datasets utility and then prepare the images to have the right format:

# Load the data and split it between train and test sets
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Scale images to the [0, 1] range
x_train = x_train.astype("float32") / 255
x_test = x_test.astype("float32") / 255
# Make sure images have shape (28, 28, 1)
x_train = np.expand_dims(x_train, -1)
x_test = np.expand_dims(x_test, -1)
print("x_train shape:", x_train.shape)
print("y_train shape:", y_train.shape)
print(x_train.shape[0], "train samples")
print(x_test.shape[0], "test samples")
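One thing that trips up PyTorch users here: np.expand_dims(x, -1) puts the channel axis last, which matches Keras's default image_data_format ("channels_last"), while torch.nn.Conv2d expects channels-first. A small NumPy sketch with a hypothetical all-zeros batch shows both layouts:

```python
import numpy as np

# Hypothetical batch shaped like the raw MNIST arrays: (num_samples, 28, 28)
x = np.zeros((5, 28, 28), dtype="float32")

# What the tutorial does: append a channel axis -> channels-last (N, H, W, C)
x_nhwc = np.expand_dims(x, -1)

# A plain PyTorch nn.Conv2d wants channels-first (N, C, H, W),
# so the channel axis has to move to position 1
x_nchw = np.transpose(x_nhwc, (0, 3, 1, 2))

print(x_nhwc.shape, x_nchw.shape)
```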

Side note: I will never understand why tutorials from the TensorFlow world scale images to [0, 1] while PyTorch tutorials scale them to [-1, 1], centered at 0. Both methods seem to work on MNIST – maybe the Keras author, François Chollet, just does not like negative numbers.
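The difference between the two conventions is just one affine step. A NumPy sketch on three hypothetical pixel values; the [-1, 1] version is what torchvision's ToTensor() followed by Normalize(mean=(0.5,), std=(0.5,)) computes, namely (x - 0.5) / 0.5 = 2x - 1:

```python
import numpy as np

# Hypothetical stand-in for uint8 MNIST pixels in the range 0..255
pixels = np.array([[0, 51, 255]], dtype=np.uint8)

# Keras tutorial style: scale to [0, 1]
scaled_01 = pixels.astype("float32") / 255

# [-1, 1] style, centered at 0: equivalent to (x - 0.5) / 0.5 applied to the [0, 1] values
scaled_pm1 = scaled_01 * 2.0 - 1.0

print(scaled_01)
print(scaled_pm1)
```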

PyTorch user comment: The list of datasets included in keras.datasets is not very long, so this will not add much utility if you are coming from PyTorch. However, it is possible to use PyTorch Datasets and DataLoaders as input for Keras Core models, so we can keep using all our beloved datasets if we want.
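A minimal sketch of wrapping the prepared NumPy arrays in a standard TensorDataset/DataLoader pair. The random arrays are hypothetical stand-ins for x_train/y_train, and the images stay channels-last because that is the layout the Keras Conv2D layers expect by default. The final model.fit(train_loader, ...) line is only shown as a comment, since it needs the model defined below and the "torch" backend:

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-ins for the prepared MNIST arrays (float32 images, integer labels)
x_train = np.random.rand(256, 28, 28, 1).astype("float32")
y_train = np.random.randint(0, 10, size=(256,))

# Standard PyTorch Dataset/DataLoader wrapping the NumPy arrays
train_ds = TensorDataset(torch.from_numpy(x_train), torch.from_numpy(y_train))
train_loader = DataLoader(train_ds, batch_size=128, shuffle=True)

x_batch, y_batch = next(iter(train_loader))
print(x_batch.shape, y_batch.shape)

# With KERAS_BACKEND="torch", the loader itself would be passed to Keras,
# e.g. model.fit(train_loader, epochs=2)  (not executed in this sketch)
```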

Creating the Keras Core model using the Sequential API

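Keras has 3 different ways to create models:

  • The Sequential API, a plain stack of layers (used in this tutorial)
  • The Functional API
  • Subclassing keras.Model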

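# Model parameters
num_classes = 10
input_shape = (28, 28, 1)

model = keras.Sequential(
    [
        # The input shape is declared explicitly instead of inferred from the first layer
        keras.layers.Input(shape=input_shape),
        keras.layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
        keras.layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
        keras.layers.MaxPooling2D(pool_size=(2, 2)),
        keras.layers.Conv2D(128, kernel_size=(3, 3), activation="relu"),
        keras.layers.Conv2D(128, kernel_size=(3, 3), activation="relu"),
        # Collapse the spatial dimensions so Dense outputs one vector per image
        keras.layers.GlobalAveragePooling2D(),
        keras.layers.Dropout(0.5),
        keras.layers.Dense(num_classes, activation="softmax"),
    ]
)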

PyTorch user comment:

  • It has always made more sense to me that TensorFlow and Keras explicitly define the input shape instead of leaving it implicit in the first layer. I’m glad we can use this syntax now as well.
  • PyTorch’s Linear layer is basically Keras’ Dense layer (with activation set to None, of course)
  • I see why they chose the sequential API for this tutorial – it’s definitely the most familiar way of defining a model for PyTorch users coming from nn.Sequential
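For comparison, here is a rough nn.Sequential counterpart in plain PyTorch. It is a sketch under a few assumptions: channels-first input (N, 1, 28, 28), and a global-average-pooling plus dropout head before the last layer. Note how nn.Conv2d needs in_channels, so the input shape is implicit in the first layer, and nn.Linear plays the role of Dense. In idiomatic PyTorch you would usually drop the final Softmax and feed raw logits to nn.CrossEntropyLoss instead.

```python
import torch
from torch import nn

num_classes = 10

model = nn.Sequential(
    nn.Conv2d(1, 64, kernel_size=3), nn.ReLU(),     # 28x28 -> 26x26
    nn.Conv2d(64, 64, kernel_size=3), nn.ReLU(),    # 26x26 -> 24x24
    nn.MaxPool2d(kernel_size=2),                    # 24x24 -> 12x12
    nn.Conv2d(64, 128, kernel_size=3), nn.ReLU(),   # 12x12 -> 10x10
    nn.Conv2d(128, 128, kernel_size=3), nn.ReLU(),  # 10x10 -> 8x8
    nn.AdaptiveAvgPool2d(1),                        # global average pooling -> (N, 128, 1, 1)
    nn.Flatten(),                                   # -> (N, 128)
    nn.Dropout(0.5),
    nn.Linear(128, num_classes),                    # like Dense(num_classes, activation=None)
    nn.Softmax(dim=1),                              # like activation="softmax"
)

model.eval()  # disable dropout for this quick shape check
with torch.no_grad():
    out = model(torch.zeros(2, 1, 28, 28))
print(out.shape)
```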

Compiling the model (this is where they lost me)

Instead of passing the optimizer, learning rate scheduler, loss, etc. to the training function, Keras insists on “compiling” the model.

Pro: It’s a neat wrapper around all the loose parts that are used to train the model.

Con: It doesn’t make any sense to me to “compile” in a language that’s interpreted. Unless they are really compiling underlying C code?


Side note/Full Disclosure: I have no idea what this compiling does under the hood, I’m merely commenting on the user experience.

Taking a closer look at all the metrics, losses and optimizers has to wait until another day for me.

Training the model (or fitting the model, as Keras users like to say)

Okay, this is where I fall in love with the simplicity. Just look at that:

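batch_size = 128
epochs = 20

callbacks = [keras.callbacks.EarlyStopping(monitor="val_loss", patience=2)]

# validation_split=0.15 holds out the last 15% of the training data for validation
model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs,
          validation_split=0.15, callbacks=callbacks)
score = model.evaluate(x_test, y_test, verbose=0)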
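What a validation split in fit() does is easy to mimic: according to the Keras docs, the validation data is taken from the last samples of x and y, before shuffling (worth knowing if your arrays happen to be sorted, e.g. by class). A NumPy sketch; the function name is hypothetical and the 0.25 fraction is just an example:

```python
import math

import numpy as np


def keras_style_validation_split(x, y, validation_split):
    """Hypothetical helper: hold out the LAST fraction of samples, like Keras's fit()."""
    split_at = int(math.floor(len(x) * (1.0 - validation_split)))
    return (x[:split_at], y[:split_at]), (x[split_at:], y[split_at:])


# Hypothetical data: 100 samples whose values equal their original index
x = np.arange(100).reshape(100, 1)
y = np.arange(100)

(x_tr, y_tr), (x_val, y_val) = keras_style_validation_split(x, y, validation_split=0.25)
print(len(x_tr), len(x_val), y_val[0])
```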

PyTorch user comment:

  • Coming from classical ML libraries like scikit-learn, it makes a lot of sense to call it “fit”.
  • You have to try this out yourself in a notebook! The visual updates during training are so sleek! No print commands necessary; it’s all behind a nice wrapping function and looks prettier than my print() commands ever did.
  • A validation split is easily implemented with a single parameter!
  • I am a bit nervous about how annoying this simplicity might become once you want to change a single thing, though… There is a more in-depth guide on this in the official documentation: “Customizing what happens in fit() with PyTorch”.
  • The callbacks are still a bit confusing to me.
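To demystify at least the EarlyStopping(monitor="val_loss", patience=2) callback: it tracks the best validation loss seen so far and stops once that value has failed to improve for `patience` epochs in a row. Below is a simplified pure-Python sketch of that patience logic with a hypothetical function name. It is not the Keras source and ignores options like restore_best_weights or baseline.

```python
def early_stopping_epoch(val_losses, patience=2, min_delta=0.0):
    """Return the (0-based) epoch at which training would stop, or None if it never stops."""
    best = float("inf")
    wait = 0  # epochs since the last improvement
    for epoch, loss in enumerate(val_losses):
        if loss < best - min_delta:
            # Improvement: remember it and reset the counter
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


# Hypothetical validation losses: improves twice, then stalls for two epochs
print(early_stopping_epoch([0.5, 0.4, 0.45, 0.42, 0.3]))
```

Note that the late improvement to 0.3 is never seen: with patience=2, training has already stopped at epoch 3.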

Add-on: Using the MPS/Metal GPU on macOS (or rather: not using MPS, because it does not work)

The tutorial did not explain how to make sure Keras Core uses MPS to utilize the GPU on my Mac (PyTorch itself does support the M1 and M2 chips now). I also could not find another tutorial or explanation on the fly.
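For reference, this is how device selection looks in plain PyTorch. It only checks what PyTorch itself reports; whether Keras Core's torch backend actually places its tensors on that device is exactly the open question here, so treat this as a sanity check of your PyTorch install rather than a fix:

```python
import torch

# Prefer Apple's Metal (MPS) backend if PyTorch reports it, then CUDA, else fall back to CPU
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Quick smoke test: allocate a tensor on the chosen device and compute something
x = torch.ones(3, device=device)
print(device, x.sum().item())
```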

A quick search led me to this currently open GitHub issue from 2 weeks ago:

This seems to be another related open issue regarding device selection:

Maybe by the time you read this, you can find updates in these issues.

If you know how to solve this, please leave me a comment or message me on Instagram/Twitter!

Conclusion: Exciting development

In summary: I love the simplicity and I’m really thankful that the endless “I don’t want to use PyTorch because I like Keras” fights can finally end.

I never really got into PyTorch Lightning because it was still a lot of code and I didn’t have a long enough project to test it out properly, so Keras as a really short and to-the-point library is a nice change.

It will be interesting to see how the public perception will be and how many PyTorch users will actually use this in the next few months and after official release of Keras 3.0. So far I have not heard anything about this – but then again everyone is very occupied with LLMs right now…

