Home Industries Education Ivy Framework Agnostic Machine Learning Build, Transpile, and Benchmark Across All Major...

Ivy Framework Agnostic Machine Learning Build, Transpile, and Benchmark Across All Major Backends

0

This comprehensive guide delves into Ivy’s exceptional capability to harmonize machine learning development across multiple frameworks. We start by crafting a fully backend-agnostic neural network that operates flawlessly on NumPy, PyTorch, TensorFlow, and JAX. Following that, we explore code transpilation, unified APIs, and sophisticated functionalities such as Ivy Containers and graph tracing, all aimed at enhancing portability, efficiency, and backend independence in deep learning workflows. Throughout this journey, we observe how Ivy streamlines model construction, training, and evaluation without confining developers to a single ecosystem.

Building a Backend-Agnostic Neural Network with Ivy

We implement a straightforward neural network entirely using Ivy, demonstrating true framework neutrality. This model runs identically across popular backends-NumPy, PyTorch, TensorFlow, and JAX-yielding consistent accuracy and performance metrics. Ivy abstracts the underlying differences between these frameworks, allowing developers to focus on model logic rather than backend-specific details.

!pip install -q ivy tensorflow torch jax jaxlib

import ivy
import numpy as np
import time

print(f"Ivy version: {ivy.__version__}")

class IvyNeuralNetwork:
    """A minimalist neural network built with Ivy, compatible with any backend."""

    def __init__(self, input_dim=4, hidden_dim=8, output_dim=3):
        self.w1 = ivy.random_uniform(shape=(input_dim, hidden_dim), low=-0.5, high=0.5)
        self.b1 = ivy.zeros((hidden_dim,))
        self.w2 = ivy.random_uniform(shape=(hidden_dim, output_dim), low=-0.5, high=0.5)
        self.b2 = ivy.zeros((output_dim,))

    def forward(self, x):
        """Execute forward propagation using Ivy operations."""
        hidden = ivy.matmul(x, self.w1) + self.b1
        hidden = ivy.relu(hidden)
        logits = ivy.matmul(hidden, self.w2) + self.b2
        return ivy.softmax(logits)

    def train_step(self, x, y, lr=0.01):
        """Perform a training iteration with manual gradient updates."""
        predictions = self.forward(x)
        loss = -ivy.mean(ivy.sum(y * ivy.log(predictions + 1e-8), axis=-1))
        error = predictions - y

        hidden_activated = ivy.relu(ivy.matmul(x, self.w1) + self.b1)
        dw2 = ivy.matmul(ivy.permute_dims(hidden_activated, (1, 0)), error) / x.shape[0]
        db2 = ivy.mean(error, axis=0)

        self.w2 -= lr * dw2
        self.b2 -= lr * db2

        return loss

def run_framework_agnostic_demo():
    """Showcase the neural network running on various backends."""
    print("n" + "="*70)
    print("PART 1: Backend-Agnostic Neural Network")
    print("="*70)

    X = np.random.randn(100, 4).astype(np.float32)
    y = np.eye(3)[np.random.randint(0, 3, 100)].astype(np.float32)

    backends = ['numpy', 'torch', 'tensorflow', 'jax']
    results = {}

    for backend in backends:
        try:
            ivy.set_backend(backend)
            if backend == 'jax':
                import jax
                jax.config.update('jax_enable_x64', True)

            print(f"n🔄 Running on {backend.upper()} backend...")

            X_ivy = ivy.array(X)
            y_ivy = ivy.array(y)

            model = IvyNeuralNetwork()

            start = time.time()
            for _ in range(50):
                loss = model.train_step(X_ivy, y_ivy, lr=0.1)
            duration = time.time() - start

            preds = model.forward(X_ivy)
            accuracy = ivy.mean(ivy.astype(ivy.argmax(preds, axis=-1) == ivy.argmax(y_ivy, axis=-1), 'float32'))

            results[backend] = {
                'loss': float(ivy.to_numpy(loss)),
                'accuracy': float(ivy.to_numpy(accuracy)),
                'time': duration
            }

            print(f"   Final Loss: {results[backend]['loss']:.4f}")
            print(f"   Accuracy: {results[backend]['accuracy']:.2%}")
            print(f"   Training Time: {results[backend]['time']:.3f} seconds")

        except Exception as e:
            print(f"   ⚠️ Error with {backend}: {str(e)[:80]}")
            results[backend] = None

    ivy.unset_backend()
    return results

Seamless Code Transpilation Across Frameworks

Next, we demonstrate Ivy’s ability to facilitate code transpilation, enabling a function originally written in PyTorch to be replicated precisely in TensorFlow, NumPy, and JAX. This interoperability is achieved through Ivy’s unified API, which abstracts framework-specific syntax and semantics, ensuring consistent outputs regardless of the backend.

def transpilation_demo():
    """Illustrate code transpilation from PyTorch to other frameworks."""
    print("n" + "="*70)
    print("PART 2: Cross-Framework Transpilation")
    print("="*70)

    try:
        import torch
        import tensorflow as tf

        def pytorch_func(x):
            """Simple PyTorch operation."""
            return torch.mean(torch.relu(x * 2.0 + 1.0))

        x_torch = torch.randn(10, 5)

        print("n📦 Original PyTorch function output:")
        torch_result = pytorch_func(x_torch)
        print(f"   PyTorch result: {torch_result.item():.6f}")

        print("n🔄 Transpilation demonstration:")
        print("   Note: ivy.transpile() is powerful but best suited for traced functions.")
        print("   For clarity, we use Ivy's unified API for equivalent computations.")

        x_np = x_torch.numpy()

        for backend in ['numpy', 'tensorflow', 'jax']:
            ivy.set_backend(backend)
            if backend == 'jax':
                import jax
                jax.config.update('jax_enable_x64', True)

            x_ivy = ivy.array(x_np)
            result = ivy.mean(ivy.relu(x_ivy * 2.0 + 1.0))
            print(f"   {backend.capitalize()} result: {float(ivy.to_numpy(result)):.6f}")

        print("n   ✅ All backend results align within numerical precision.")

        ivy.unset_backend()

    except Exception as e:
        print(f"⚠️ Transpilation demo error: {str(e)[:80]}")

Unified API: Consistency Across Diverse Operations

We further validate Ivy’s unified API by executing a variety of mathematical, neural network, and statistical operations across multiple backends. This confirms that Ivy maintains consistent syntax and output, simplifying multi-framework development into a single, coherent interface.

def unified_api_demo():
    """Demonstrate Ivy's unified API across different operations."""
    print("n" + "="*70)
    print("PART 3: Unified API Demonstration")
    print("="*70)

    operations = [
        ("Matrix Multiplication", lambda x: ivy.matmul(x, ivy.permute_dims(x, (1, 0)))),
        ("Element-wise Arithmetic", lambda x: ivy.add(ivy.multiply(x, x), 2)),
        ("Reduction Operations", lambda x: ivy.mean(ivy.sum(x, axis=0))),
        ("Neural Network Activation", lambda x: ivy.mean(ivy.relu(x))),
        ("Statistical Computation", lambda x: ivy.std(x)),
        ("Broadcasting Example", lambda x: ivy.multiply(x, ivy.array([1.0, 2.0, 3.0, 4.0]))),
    ]

    X = np.random.randn(5, 4).astype(np.float32)

    for op_name, op_func in operations:
        print(f"n🔧 {op_name}:")
        for backend in ['numpy', 'torch', 'tensorflow', 'jax']:
            try:
                ivy.set_backend(backend)
                if backend == 'jax':
                    import jax
                    jax.config.update('jax_enable_x64', True)

                x_ivy = ivy.array(X)
                result = op_func(x_ivy)
                result_np = ivy.to_numpy(result)

                if result_np.shape == ():
                    print(f"   {backend:12s}: scalar = {float(result_np):.4f}")
                else:
                    print(f"   {backend:12s}: shape={result_np.shape}, mean={np.mean(result_np):.4f}")

            except Exception as e:
                print(f"   {backend:12s}: ⚠️ {str(e)[:60]}")

        ivy.unset_backend()

Exploring Ivy’s Advanced Capabilities

Beyond basic operations, Ivy offers powerful features such as ivy.Container for managing nested data structures, compliance with the Array API standard across backends, and the ability to chain complex multi-step computations. These capabilities enable scalable and maintainable model development.

def advanced_features_demo():
    """Showcase Ivy's advanced functionalities."""
    print("n" + "="*70)
    print("PART 4: Advanced Ivy Features")
    print("="*70)

    print("n📦 Ivy Containers for structured parameter management:")
    try:
        ivy.set_backend('torch')

        container = ivy.Container({
            'layer1': {'weights': ivy.random_uniform((4, 8)), 'bias': ivy.zeros((8,))},
            'layer2': {'weights': ivy.random_uniform((8, 3)), 'bias': ivy.zeros((3,))}
        })

        print(f"   Container keys: {list(container.keys())}")
        print(f"   Layer1 weights shape: {container['layer1']['weights'].shape}")
        print(f"   Layer2 bias shape: {container['layer2']['bias'].shape}")

        def double_values(x, _):
            return x * 2.0

        scaled_container = container.cont_map(double_values)
        print("   ✅ Scaled all tensors within the container.")

    except Exception as e:
        print(f"   ⚠️ Container demo error: {str(e)[:80]}")

    print("n🔗 Array API compliance across backends:")
    tested_backends = []
    for backend in ['numpy', 'torch', 'tensorflow', 'jax']:
        try:
            ivy.set_backend(backend)
            if backend == 'jax':
                import jax
                jax.config.update('jax_enable_x64', True)

            x = ivy.array([1.0, 2.0, 3.0])
            y = ivy.array([4.0, 5.0, 6.0])

            result = ivy.sqrt(ivy.square(x) + ivy.square(y))
            print(f"   {backend:12s}: L2 norm calculation successful ✅")
            tested_backends.append(backend)
        except Exception as e:
            print(f"   {backend:12s}: ⚠️ {str(e)[:50]}")

    print(f"n   Successfully validated {len(tested_backends)} backends.")

    print("n🎯 Executing complex chained operations:")
    try:
        ivy.set_backend('torch')

        x = ivy.random_uniform((10, 5), low=0, high=1)
        result = ivy.mean(ivy.relu(ivy.matmul(x, ivy.permute_dims(x, (1, 0)))), axis=0)

        print(f"   Operation chain: matmul → relu → mean")
        print(f"   Input shape: (10, 5), Output shape: {result.shape}")
        print("   ✅ Complex operation executed successfully.")

    except Exception as e:
        print(f"   ⚠️ Error during complex operation: {str(e)[:80]}")

    ivy.unset_backend()

Performance Benchmarking Across Frameworks

To assess real-world efficiency, we benchmark a composite operation involving matrix multiplication, ReLU activation, mean reduction, and summation across NumPy, PyTorch, TensorFlow, and JAX. Each backend is warmed up before timing 50 iterations, providing insights into latency and throughput for informed backend selection.

def benchmark_operation(op_func, x, iterations=50):
    """Measure execution time of a given operation."""
    start_time = time.time()
    for _ in range(iterations):
        _ = op_func(x)
    return time.time() - start_time

def performance_benchmark():
    """Compare execution speed across different backends."""
    print("n" + "="*70)
    print("PART 5: Performance Benchmarking")
    print("="*70)

    X = np.random.randn(100, 100).astype(np.float32)

    def complex_op(x):
        z = ivy.matmul(x, ivy.permute_dims(x, (1, 0)))
        z = ivy.relu(z)
        z = ivy.mean(z, axis=0)
        return ivy.sum(z)

    print("n⏱️ Benchmarking matrix operations (50 iterations):")
    print("   Operation sequence: matmul → relu → mean → sum")

    for backend in ['numpy', 'torch', 'tensorflow', 'jax']:
        try:
            ivy.set_backend(backend)
            if backend == 'jax':
                import jax
                jax.config.update('jax_enable_x64', True)

            x_ivy = ivy.array(X)
            _ = complex_op(x_ivy)  # Warm-up

            elapsed = benchmark_operation(complex_op, x_ivy, iterations=50)
            print(f"   {backend:12s}: {elapsed:.4f} seconds total ({elapsed/50*1000:.2f} ms per iteration)")

        except Exception as e:
            print(f"   {backend:12s}: ⚠️ {str(e)[:60]}")

    ivy.unset_backend()

Summary and Future Directions

This tutorial has showcased Ivy’s transformative approach to machine learning development: write your code once and deploy it seamlessly across multiple frameworks. We observed consistent model behavior, effortless backend switching, and comparable performance across NumPy, PyTorch, TensorFlow, and JAX. Ivy’s unified API, advanced container management, and graph optimization tools empower developers to build modular, efficient, and portable ML solutions.

Looking ahead, consider leveraging ivy.Container for organizing complex model parameters, exploring ivy.trace_graph() to optimize computation graphs, and experimenting with different backends to identify the best fit for your workloads. Ivy’s growing ecosystem and comprehensive documentation make it an invaluable asset for future-proof machine learning projects.


Ready to elevate your machine learning workflow? Dive deeper into Ivy’s capabilities and join a vibrant community of developers embracing framework-agnostic innovation.

Exit mobile version