1. Why Deep Learning in the Browser?

Running models on the client side opens up unique possibilities that server-side Python can't easily match:

  • Privacy: Data never leaves the user's device, which suits health or other sensitive data.
  • Low Latency: No server round-trips. Response time depends only on the device and the model size.
  • Device Access: Direct access to sensors like cameras, microphones, and accelerometers.
  • Cost: Compute happens on the user's device (on the GPU where available), reducing server costs.

2. Core Concepts in TensorFlow.js

Key differences from Python's TensorFlow/Keras:

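  • Tensors: The core data structure. Tensor memory is not garbage-collected on the WebGL backend, so you must release it yourself with tf.tidy() or tensor.dispose() to prevent memory leaks.
  • Async/Await: Long-running operations such as training (model.fit()), model loading, and reading tensor values (tensor.data()) return Promises so they don't block the UI thread. Note that model.predict() itself returns a tensor synchronously.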
  • Layers API: Modeled after Keras, making it familiar for Python developers.
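
The memory rule above can be sketched as follows. This is a minimal illustration, not a prescribed pattern: normalize and tidyDemo are hypothetical helper names, and tf is passed in as a parameter on the assumption that it is whatever namespace you loaded (the CDN global or the NPM import).

```javascript
// Hypothetical helper: scales values to the range [0, 1].
// `tf` is assumed to be the TensorFlow.js namespace loaded by the page.
function normalize(tf, values) {
  // tf.tidy() disposes every intermediate tensor created inside the callback
  // and keeps only the tensor that the callback returns.
  return tf.tidy(() => {
    const x = tf.tensor1d(values);
    const min = x.min();
    const max = x.max();
    return x.sub(min).div(max.sub(min));
  });
}

// Hypothetical demo: reads the result and checks that nothing leaked.
async function tidyDemo(tf) {
  const before = tf.memory().numTensors;
  const scaled = normalize(tf, [2, 4, 6, 8]);
  const values = Array.from(await scaled.data()); // asynchronous read
  scaled.dispose(); // the tensor returned by tf.tidy() is ours to free
  const leaked = tf.memory().numTensors - before;
  return { values, leaked };
}
```

Calling tidyDemo(tf) should report zero leaked tensors, because the intermediates are freed by tf.tidy() and the result is freed explicitly with dispose().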

3. Setup

You can include TensorFlow.js via a CDN script tag for simple HTML pages, or install it via NPM for build pipelines.

<!-- CDN -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"></script>

// NPM
// npm install @tensorflow/tfjs
import * as tf from '@tensorflow/tfjs';

4. Creating and Training a Model

Here is a complete example of defining a simple linear regression model, generating synthetic data, and training it in the browser.

async function run() {
  // 1. Define the model (Equivalent to Keras Sequential)
  const model = tf.sequential();
  model.add(tf.layers.dense({units: 1, inputShape: [1]}));

  // 2. Compile the model
  model.compile({loss: 'meanSquaredError', optimizer: 'sgd'});

  // 3. Generate synthetic data
  const xs = tf.tensor2d([-1, 0, 1, 2, 3, 4], [6, 1]);
  const ys = tf.tensor2d([-3, -1, 1, 3, 5, 7], [6, 1]);

  // 4. Train the model (await so prediction only runs after training finishes)
  console.log('Training...');
  await model.fit(xs, ys, {epochs: 250});

  // 5. Predict
  const input = tf.tensor2d([10], [1, 1]);
  const output = model.predict(input);
  output.print(); // Should print a value close to 19 (the data follows y = 2x - 1)

  // Cleanup memory
  xs.dispose();
  ys.dispose();
  input.dispose();
  output.dispose();
}

run();
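
Because model.fit() is asynchronous, the page stays responsive during training and can report progress. The sketch below shows one way to do that with the callbacks option of model.fit(); trainWithProgress and onProgress are hypothetical names introduced for illustration, and model, xs, and ys are assumed to be the objects from the example above.

```javascript
// Hypothetical helper: trains `model` and reports the loss after each epoch.
// Returns the loss of the final epoch.
async function trainWithProgress(model, xs, ys, epochs, onProgress) {
  const history = await model.fit(xs, ys, {
    epochs,
    callbacks: {
      // Invoked by TensorFlow.js after every epoch; `epoch` is zero-based
      // and logs.loss holds that epoch's loss value.
      onEpochEnd: (epoch, logs) => onProgress(epoch + 1, epochs, logs.loss),
    },
  });
  // model.fit() resolves to a History object with one loss entry per epoch.
  const losses = history.history.loss;
  return losses[losses.length - 1];
}
```

Inside run(), the plain model.fit() call could be swapped for something like await trainWithProgress(model, xs, ys, 250, (done, total, loss) => console.log(done, total, loss)), or the callback could update a progress bar instead of logging.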

5. Loading Pre-trained Models

One of the most common use cases is taking a model trained in Python, converting it to the TensorFlow.js format (a model.json file plus binary weight files), and using it for inference in the browser.

// Converting a Python Keras model (the converter ships with: pip install tensorflowjs):
// tensorflowjs_converter --input_format keras model.h5 target_dir

async function loadAndPredict() {
    const model = await tf.loadLayersModel('path/to/model.json');

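    // Dummy input; replace with real data shaped like the model's input
    const input = tf.zeros([1, 224, 224, 3]);
    const prediction = model.predict(input);

    prediction.print();

    // Cleanup memory
    input.dispose();
    prediction.dispose();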
}
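
In practice you usually want a value out of the prediction rather than a printed tensor. The sketch below assumes the loaded model is a classifier with a single output of shape [1, numClasses], which the text above does not state; predictTopClass is a hypothetical helper name, and tf is passed in as a parameter.

```javascript
// Hypothetical helper: runs one inference and returns the index of the
// highest-scoring class. Assumes a single output of shape [1, numClasses].
async function predictTopClass(tf, model, input) {
  // The raw prediction is created inside tf.tidy(), so it is disposed
  // automatically; only the small index tensor survives the callback.
  const indexTensor = tf.tidy(() => model.predict(input).argMax(-1));
  const [index] = await indexTensor.data(); // asynchronous read of the result
  indexTensor.dispose();
  return index;
}
```

The caller still owns input and should dispose of it once it is no longer needed.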

6. Further Reading

Performance Tip: In the browser, TensorFlow.js uses the WebGL backend by default, when it is available, to run computations on the GPU. Dispose of tensors as soon as they are no longer needed; leaked GPU memory accumulates and can eventually crash the browser tab.
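
Since WebGL is not available on every device, it can help to check which backend is actually in use. The following is a minimal sketch under the assumption that falling back from WebGL to the CPU backend is acceptable for your app; pickBackend is a hypothetical helper name, and tf is passed in as a parameter.

```javascript
// Hypothetical helper: tries each backend in order and returns the name of
// the first one that initializes.
async function pickBackend(tf, preferred = ['webgl', 'cpu']) {
  for (const name of preferred) {
    try {
      // tf.setBackend() resolves to true when initialization succeeded.
      if (await tf.setBackend(name)) {
        await tf.ready();
        return tf.getBackend();
      }
    } catch (err) {
      // Backend is not registered or failed to initialize; try the next one.
    }
  }
  throw new Error('No TensorFlow.js backend could be initialized');
}
```

A page could log the result of await pickBackend(tf) at startup and warn the user when it is 'cpu', since inference will be noticeably slower there.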