1. Why Deep Learning in the Browser?
Running models on the client side opens up unique possibilities that server-side Python can't easily match:
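- Privacy: Data never leaves the user's device, which makes client-side inference well suited to health records and other sensitive data.
- Low Latency: Inference runs without server round-trips, so interactive features can respond in real time.
- Device Access: Browser APIs give direct access (with the user's permission) to sensors like cameras, microphones, and accelerometers.
- Cost: Compute runs on the user's hardware (the GPU via WebGL when available), reducing server costs.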
2. Core Concepts in TensorFlow.js
Key points for developers coming from Python's TensorFlow/Keras:
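- Tensors: The core data structure. Tensors held by the WebGL backend are not freed by JavaScript's garbage collector, so you must release them manually with tf.tidy() or tensor.dispose() to prevent GPU memory leaks.
- Async/Await: Long-running or data-transferring operations, such as training (model.fit), loading models and data, and reading tensor values back (tensor.data()), are asynchronous to avoid blocking the UI thread. Note that model.predict() itself returns a tensor synchronously.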
- Layers API: Modeled after Keras, making it familiar for Python developers.
3. Setup
You can include TensorFlow.js via a CDN script tag for simple HTML pages, or install it via NPM for build pipelines.
<!-- CDN: loads the library and exposes a global tf object -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"></script>
// NPM: run "npm install @tensorflow/tfjs", then import it in your JavaScript:
import * as tf from '@tensorflow/tfjs';
4. Creating and Training a Model
Here is a complete example (assuming tf is loaded as in Setup) that defines a simple linear regression model, generates synthetic data, and trains it in the browser.
async function run() {
  // 1. Define the model (equivalent to a Keras Sequential model)
  const model = tf.sequential();
  model.add(tf.layers.dense({units: 1, inputShape: [1]}));
  // 2. Compile with mean squared error loss and stochastic gradient descent
  model.compile({loss: 'meanSquaredError', optimizer: 'sgd'});
  // 3. Generate synthetic data that follows y = 2x - 1
  const xs = tf.tensor2d([-1, 0, 1, 2, 3, 4], [6, 1]);
  const ys = tf.tensor2d([-3, -1, 1, 3, 5, 7], [6, 1]);
  // 4. Train the model; fit() returns a Promise, so await it before predicting
  console.log('Training...');
  await model.fit(xs, ys, {epochs: 250});
  // 5. Predict for x = 10; tidy() frees the temporary input tensor
  const output = tf.tidy(() => model.predict(tf.tensor2d([10], [1, 1])));
  output.print(); // Should print a value close to 19 (2 * 10 - 1)
  // Clean up: free the remaining tensors and the model's weights
  xs.dispose();
  ys.dispose();
  output.dispose();
  model.dispose();
}
run();
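To see roughly what model.fit() is doing for this model, here is a dependency-free sketch of full-batch gradient descent on mean squared error over the same six points. It is an illustration, not TensorFlow.js's internal implementation, which uses automatic differentiation, mini-batching, and random weight initialization; fitLinear is a hypothetical name, and the learning rate of 0.05, 1000 epochs, and zero starting weights are assumptions chosen so the plain loop converges.

```javascript
// Plain-JS sketch: fit y = w * x + b by full-batch gradient descent on MSE.
// learningRate, epochs, and zero initialization are illustrative assumptions.
function fitLinear(xs, ys, { learningRate = 0.05, epochs = 1000 } = {}) {
  let w = 0;
  let b = 0;
  const n = xs.length;
  for (let epoch = 0; epoch < epochs; epoch++) {
    let gradW = 0;
    let gradB = 0;
    for (let i = 0; i < n; i++) {
      const error = w * xs[i] + b - ys[i]; // prediction minus target
      gradW += (2 / n) * error * xs[i];    // partial derivative of MSE w.r.t. w
      gradB += (2 / n) * error;            // partial derivative of MSE w.r.t. b
    }
    // Step both parameters against their gradients.
    w -= learningRate * gradW;
    b -= learningRate * gradB;
  }
  return { w, b, predict: (x) => w * x + b };
}

const { w, b, predict } = fitLinear([-1, 0, 1, 2, 3, 4], [-3, -1, 1, 3, 5, 7]);
console.log(w.toFixed(3), b.toFixed(3), predict(10).toFixed(3)); // 2.000 -1.000 19.000
```

Because the data lie exactly on y = 2x - 1, the loop recovers w = 2 and b = -1, which is why the TensorFlow.js model above should predict a value near 19 for x = 10.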
5. Loading Pre-trained Models
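One of the most common use cases is loading a model trained in Python and using it for inference in the browser. The tensorflowjs_converter tool converts the saved model into a model.json file (the architecture plus a weights manifest) and one or more binary weight files.
// Convert a Keras model in Python first (pip install tensorflowjs):
// tensorflowjs_converter --input_format keras path/to/model.h5 path/to/target_dir
async function loadAndPredict() {
  // model.json is fetched over HTTP; its weight files are resolved relative to it
  const model = await tf.loadLayersModel('path/to/model.json');
  // Dummy input: a batch of one 224x224 RGB image; tidy() frees it after predict()
  const prediction = tf.tidy(() => model.predict(tf.zeros([1, 224, 224, 3])));
  prediction.print();
  prediction.dispose();
}
loadAndPredict();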