Building Better Deep Learning Models

Keras, an intuitive deep learning API, is layered over TensorFlow to simplify model construction using declarative syntax. TensorFlow primarily focuses on tensor computations, serving as a machine learning framework.

In its initial stages, TensorFlow facilitated deep learning calculations, yet models required constructing layer sequences via function calls. Keras greatly simplified this by offering a high-level API for model definition and data fitting.

Keras was originally backend-agnostic and first ran on early computation platforms such as Theano. With TensorFlow 2.0, Keras became TensorFlow's built-in high-level API (tf.keras), although Keras 3 has since restored support for other backends. This integration shields users from GPU and TPU intricacies and lets them assemble neural networks much like Lego blocks. These advances helped usher in the current AI era.

Need for the Keras Functional API

While embracing a declarative approach to crafting deep learning models is a positive stride, crafting a truly effective model for real-world challenges requires more. A top prerequisite is the capability to construct pre-tested, readily usable modules, facilitating modularity and reusability within intricate models. The functional API within Keras mirrors assembling an aircraft from Lego blocks, allowing for the construction of intricate deep learning solutions.

The tensor shape is the second aspect to consider. Keras' functional API describes input shapes as (None, dimensions), where None stands for the batch dimension. For instance, a shape of (None, 3) lets Keras accept any number of samples, provided each sample is structured as three channels. This adaptability proves invaluable during feature engineering, as input data shapes evolve. These two features establish Keras' functional API as the preferred choice for developers, including High Plains' adept team of machine learning engineers.
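As a minimal sketch of the (None, dimensions) idea, the snippet below declares an input with three features per sample and passes batches of two different sizes through the same model. The single Dense layer and the batch sizes are illustrative assumptions, not part of any model discussed here.

```python
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

# Three features per sample; the batch dimension is left unspecified (None).
inputs = keras.Input(shape=(3,), name="channels")
outputs = layers.Dense(1)(inputs)
model = keras.Model(inputs=inputs, outputs=outputs)

# The symbolic input reports a shape with None in the batch position.
input_shape = tuple(inputs.shape)
print(input_shape)

# The same model accepts batches of different sizes.
small_batch = model(np.zeros((2, 3), dtype="float32"))
large_batch = model(np.zeros((50, 3), dtype="float32"))
print(small_batch.shape, large_batch.shape)
```

Only the last dimension is fixed by the model definition, so the number of samples per call is free to vary.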

Need for a Declarative API for NN Development

Most straightforward and small Neural Networks (NN) have the following structure.

Keras makes all of this very easy: defining layers, selecting a loss function and an optimizer, and running a training loop that feeds the training data set in batches, calculates the loss from the predicted output, and readjusts the weights to reduce the error on the next batch.
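To make concrete what such a training loop does, here is a hand-written sketch in plain Python, with no Keras involved. It fits a single weight to toy data with mini-batch gradient descent. The data, learning rate, and batch size are made-up values chosen only for illustration.

```python
# Toy data following y = 2x; the "model" is a single weight w.
xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
ys = [2.0 * x for x in xs]

w = 0.0
learning_rate = 0.01
batch_size = 2


def mean_squared_error(weight, batch_x, batch_y):
    # Average squared difference between predictions and targets.
    return sum((weight * x - y) ** 2 for x, y in zip(batch_x, batch_y)) / len(batch_x)


initial_loss = mean_squared_error(w, xs, ys)

for epoch in range(200):
    # Walk through the training data one batch at a time.
    for start in range(0, len(xs), batch_size):
        batch_x = xs[start:start + batch_size]
        batch_y = ys[start:start + batch_size]
        # Gradient of the mean squared error with respect to w.
        grad = sum(2 * (w * x - y) * x for x, y in zip(batch_x, batch_y)) / len(batch_x)
        # Adjust the weight to reduce the error on the next batch.
        w -= learning_rate * grad

final_loss = mean_squared_error(w, xs, ys)
print(w, initial_loss, final_loss)
```

In Keras, the loss, the optimizer, and the batching are declared once, and the library runs the equivalent of this loop.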

The following is an example of a Keras-based digit classification program. It takes the famous MNIST dataset of 28 x 28-pixel images of handwritten digits and classifies each image as a number from 0 to 9. The MNIST dataset contains 60,000 training images and 10,000 test images.

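from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.datasets import mnist

(train_images, train_labels), _ = mnist.load_data()
train_images = train_images.reshape((60000, 28 * 28))  # flatten each image to 784 values
train_images = train_images.astype("float32") / 255    # scale pixel values to [0, 1]

model = keras.Sequential([
    layers.Dense(512, activation="relu"),
    layers.Dense(10, activation="softmax")
])
model.compile(optimizer="rmsprop", loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])
# The optimizer, loss, epochs, and batch_size here are assumed, illustrative choices
model.fit(train_images, train_labels, epochs=5, batch_size=128)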

This sequential model takes a tensor of 60000 x 784 float32 values (28 x 28 = 784 pixels per image). The first dense layer outputs 512 values per image, and the second dense layer outputs 10 values per image. Because of the softmax activation, those 10 values are probabilities that sum to 1, one for each digit from 0 to 9, and the digit with the highest probability is the prediction. Keras makes it possible to define a stack of layers and compile the model by specifying the desired loss function, optimizer, and metrics, through which the user tracks model accuracy as it undergoes training.
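The layer sizes and the softmax output can be checked with plain Python arithmetic, without loading the dataset. In the sketch below, dense_param_count and softmax are hypothetical helper names, and the ten scores are made-up numbers standing in for the output layer's values before the softmax is applied.

```python
import math


def dense_param_count(input_dim, units):
    # One weight per (input, unit) pair plus one bias per unit.
    return input_dim * units + units


def softmax(scores):
    # Subtract the maximum first for numerical stability.
    highest = max(scores)
    exps = [math.exp(s - highest) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


hidden_params = dense_param_count(28 * 28, 512)  # first Dense layer
output_params = dense_param_count(512, 10)       # second Dense layer

# Ten made-up scores, one per digit from 0 to 9.
scores = [0.1, 0.3, 2.5, 0.0, -1.0, 0.2, 0.4, 0.1, 0.0, -0.5]
probs = softmax(scores)
predicted_digit = probs.index(max(probs))

print(hidden_params, output_params, sum(probs), predicted_digit)
```

The output is a full probability distribution over the ten digits rather than a single bit, and the prediction is the index of the largest probability.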

Complex Models with Keras Functional API

While keras.Sequential() lets you stack layers and quickly create a deep learning model, most production-grade, real-world models are far more complex, with multiple inputs and outputs. They look more like a graph connecting multiple vertices than a stack of layers.

Keras offers the functional API to build such models like Lego building blocks. The following hypothetical model predicts which department a ticket belongs to and what the ticket's priority is. The model has three separate inputs: the title, the body text, and the tags of each ticket.

The following diagram shows what the model would look like.

The following code snippet shows the functional API to realize this model.

from tensorflow import keras
from tensorflow.keras import layers

# Assumed definitions for the activation constants used below
ACTIVATION_RELU = "relu"
ACTIVATION_SIGMOID = "sigmoid"
ACTIVATION_SOFTMAX = "softmax"

vocabulary_size = 10000
num_tags = 100
num_departments = 4

# Three named inputs, one per ticket field
ticket_title = keras.Input(shape=(vocabulary_size,), name="title")
ticket_body = keras.Input(shape=(vocabulary_size,), name="body")
ticket_tags = keras.Input(shape=(num_tags,), name="tags")
inputs = [ticket_title, ticket_body, ticket_tags]

# Merge the inputs into one shared feature vector
features = layers.concatenate([ticket_title, ticket_body, ticket_tags])
features = layers.Dense(64, activation=ACTIVATION_RELU)(features)

# Two output heads built on the shared features
priority = layers.Dense(1, activation=ACTIVATION_SIGMOID, name="priority")(features)
department = layers.Dense(num_departments, activation=ACTIVATION_SOFTMAX, name="department")(features)

outputs = [priority, department]

model = keras.Model(inputs=inputs, outputs=outputs)

The inputs are defined with shapes such as (None, 10000), where None is the batch dimension, which means the model can work with any batch size. This also eases feature engineering, because the model can consume features produced by any other network.
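As a rough usage sketch, the ticket model can be called with batches of different sizes and returns one tensor per output. The snippet rebuilds the model with much smaller, assumed sizes so that it runs quickly, and random_batch is a hypothetical helper that produces random 0/1 vectors in place of real encoded tickets.

```python
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

# Small illustrative sizes, not the ones used in the article's model.
vocabulary_size, num_tags, num_departments = 100, 10, 4

title = keras.Input(shape=(vocabulary_size,), name="title")
body = keras.Input(shape=(vocabulary_size,), name="body")
tags = keras.Input(shape=(num_tags,), name="tags")

features = layers.concatenate([title, body, tags])
features = layers.Dense(64, activation="relu")(features)
priority = layers.Dense(1, activation="sigmoid", name="priority")(features)
department = layers.Dense(num_departments, activation="softmax", name="department")(features)
model = keras.Model(inputs=[title, body, tags], outputs=[priority, department])


def random_batch(num_samples):
    # Random 0/1 vectors standing in for encoded tickets.
    rng = np.random.default_rng(0)
    return [
        rng.integers(0, 2, size=(num_samples, vocabulary_size)).astype("float32"),
        rng.integers(0, 2, size=(num_samples, vocabulary_size)).astype("float32"),
        rng.integers(0, 2, size=(num_samples, num_tags)).astype("float32"),
    ]


# Inputs are passed in the same order as they were declared.
for num_samples in (1, 32):
    priority_pred, department_pred = model(random_batch(num_samples))
    print(num_samples, priority_pred.shape, department_pred.shape)
```

Each call yields one priority score per ticket and one probability per department, whatever the batch size.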

The functional API lets you name any layer, and simple symbolic linking makes model creation easier.

For example, the priority layer takes input from the features layer.

priority = layers.Dense(1, activation=ACTIVATION_SIGMOID, name="priority")(features)

If the shape of the features layer changes, the shape of the priority layer's weights is inferred again when the model is rebuilt, so no manual adjustment is needed.
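A small sketch of this behavior, under assumed sizes: build_priority_model is a hypothetical helper that varies only the width of a shared features layer, and the priority layer is then looked up by name to inspect the shape of its weight matrix.

```python
from tensorflow import keras
from tensorflow.keras import layers


def build_priority_model(feature_width):
    # Hypothetical helper: only the width of the features layer varies.
    inputs = keras.Input(shape=(20,), name="ticket")
    features = layers.Dense(feature_width, activation="relu", name="features")(inputs)
    priority = layers.Dense(1, activation="sigmoid", name="priority")(features)
    return keras.Model(inputs=inputs, outputs=priority)


kernel_shapes = {}
for width in (64, 32):
    model = build_priority_model(width)
    # Layers can be retrieved by the name given at construction time.
    kernel_shapes[width] = tuple(model.get_layer("priority").kernel.shape)

print(kernel_shapes)
```

The definition of the priority layer is identical in both models. Only its input changes, and Keras derives the weight shape from that input.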

Beyond Functional API

For advanced users, Keras also provides an easy way to adopt a PyTorch-like programming model. A Keras user can extend the keras.Model class and fully customize the training loop, with forward steps and custom metrics, simply by extending the model base class, e.g.

class CustomerTicketModel(keras.Model):
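A fuller sketch of such a subclass might look like the following. It covers only the forward pass, not a customized training loop. The layer sizes, the attribute names, and the dictionary keys "title", "body", and "tags" are assumptions chosen to match the ticket example, and the all-zero batch is placeholder data.

```python
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers


class CustomerTicketModel(keras.Model):
    def __init__(self, num_departments):
        super().__init__()
        # Layers are created once here and reused on every call.
        self.concat_layer = layers.Concatenate()
        self.mixing_layer = layers.Dense(64, activation="relu")
        self.priority_scorer = layers.Dense(1, activation="sigmoid")
        self.department_classifier = layers.Dense(num_departments, activation="softmax")

    def call(self, inputs):
        # The forward pass is ordinary Python code.
        features = self.concat_layer([inputs["title"], inputs["body"], inputs["tags"]])
        features = self.mixing_layer(features)
        return self.priority_scorer(features), self.department_classifier(features)


# Small illustrative sizes and placeholder data.
num_samples, vocabulary_size, num_tags = 4, 100, 10
batch = {
    "title": np.zeros((num_samples, vocabulary_size), dtype="float32"),
    "body": np.zeros((num_samples, vocabulary_size), dtype="float32"),
    "tags": np.zeros((num_samples, num_tags), dtype="float32"),
}

model = CustomerTicketModel(num_departments=4)
priority, department = model(batch)
print(priority.shape, department.shape)
```

Unlike the functional version, the graph of layers is not declared up front. It exists only as the code inside call, which allows arbitrary Python logic in the forward pass.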

All code examples are adapted from the excellent book “Deep Learning with Python,” 2nd edition, by François Chollet, who is also the author of Keras.

About the author

Ajmal Mahmood is the Chief Architect for High Plains Computing (HPC). HPC provides cloud DevOps and MLOps services and helps roll out ML models to production using AWS cloud and Kubernetes.
