
A Coding Implementation to Build Neural Memory Agents with Differentiable Memory, Meta-Learning, and Experience Replay for Continual Adaptation in Dynamic Environments


Building a Neural Memory Agent for Lifelong Learning

This guide walks through the design of a neural memory agent that keeps acquiring new knowledge without erasing what it learned before. We build a memory-augmented neural network that combines a simplified memory module in the style of a Differentiable Neural Computer (DNC), using content-based addressing only, with prioritized experience replay and meta-learning. The PyTorch implementation relies on content-based memory addressing and replay to reduce catastrophic forgetting as the agent learns a sequence of tasks.

Setting Up the Memory Configuration

We start by importing the necessary libraries and defining a configuration class that sets the key parameters of the memory system. These are the number of memory slots, the dimensionality of each memory vector, and the number of read and write heads. The controller below implements a single write head, so numwriteheads is informational only. These values determine how information is stored and accessed during training.

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from collections import deque
import matplotlib.pyplot as plt
from dataclasses import dataclass

@dataclass
class MemoryConfig:
    memorysize: int = 128
    memorydim: int = 64
    numreadheads: int = 4
    numwriteheads: int = 1
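The configuration values also fix several derived sizes used later. The concatenated read vector is numreadheads × memorydim wide, and the controller's output layer takes the LSTM state plus all read vectors. The sketch below computes these sizes. MemoryConfigSketch and its helpers are hypothetical names introduced for illustration, and hidden_dim=128 assumes the agent's default hidden size.

```python
from dataclasses import dataclass

@dataclass
class MemoryConfigSketch:
    # Hypothetical mirror of MemoryConfig with the same default values
    memory_size: int = 128
    memory_dim: int = 64
    num_read_heads: int = 4
    num_write_heads: int = 1

    @property
    def read_width(self) -> int:
        # Each read head returns one memory_dim vector; the heads are concatenated
        return self.num_read_heads * self.memory_dim

    @property
    def memory_entries(self) -> int:
        # Number of scalars in the memory matrix
        return self.memory_size * self.memory_dim

    def output_in_features(self, hidden_dim: int) -> int:
        # The output layer sees the LSTM state concatenated with all read vectors
        return hidden_dim + self.read_width

cfg = MemoryConfigSketch()
print(cfg.read_width, cfg.memory_entries, cfg.output_in_features(hidden_dim=128))
```

With the defaults, the read vectors add 256 features to the 128-dimensional controller state. The output layer therefore maps 384 features back to the 64-dimensional input space.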

Neural Memory Bank and Controller: Core Components

The Neural Memory Bank acts as a dynamic storage module that supports content-based retrieval and modification of memory slots. It addresses slots by their cosine similarity to a key. Both reads and writes are soft operations weighted over all slots. Because the memory is detached after each write, gradients flow through the read weights but not through the stored contents.
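In equation form, the memory bank code computes roughly the following. The symbols are introduced here for illustration: N memory slots with rows M_i, a key k, a positive strength β, write weights w, an erase vector e, a write vector a, and a usage vector u.

```latex
\[
w_i = \frac{\exp\big(\beta \cos(\mathbf{k}, \mathbf{M}_i)\big)}{\sum_{j=1}^{N} \exp\big(\beta \cos(\mathbf{k}, \mathbf{M}_j)\big)},
\qquad
\cos(\mathbf{k}, \mathbf{M}_i) = \frac{\mathbf{k} \cdot \mathbf{M}_i}{\lVert \mathbf{k} \rVert \, \lVert \mathbf{M}_i \rVert}
\]
\[
\mathbf{r} = \sum_{i=1}^{N} w_i \, \mathbf{M}_i,
\qquad
\mathbf{M} \leftarrow \mathbf{M} \circ \big(\mathbf{1} - \mathbf{w} \, \mathbf{e}^{\top}\big) + \mathbf{w} \, \mathbf{a}^{\top},
\qquad
\mathbf{u} \leftarrow 0.99 \, \mathbf{u} + \mathbf{w}
\]
```

Each read head applies the first two formulas with its own key and strength, and the resulting read vectors are concatenated. The symbol ∘ denotes the element-wise product.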

The Memory Controller is an LSTM-based network that interfaces with the memory bank. At each step it emits a key and a strength for every read head, plus a write key, write strength, write vector and erase vector. It then combines its hidden state with the read vectors to produce the output. This lets the agent recall relevant content and update its memory in response to new inputs.
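The role of the strength is easiest to see on a toy memory. The following pure-Python sketch uses the standard library only, and its helper names are hypothetical. It mirrors the controller's use of softplus to turn a raw output into a positive strength β, and the bank's cosine-similarity softmax. A larger β concentrates the weights on the most similar slot.

```python
import math

def softplus(z: float) -> float:
    # Maps any real number to a positive value, as F.softplus does
    return math.log1p(math.exp(z))

def normalize(v, eps=1e-12):
    # Unit-length copy of v; eps avoids division by zero (like F.normalize)
    n = max(math.sqrt(sum(a * a for a in v)), eps)
    return [a / n for a in v]

def content_weights(key, memory, beta):
    # Softmax over beta * cosine similarity, one weight per memory slot
    k = normalize(key)
    scores = [beta * sum(a * b for a, b in zip(k, normalize(row))) for row in memory]
    top = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def read(key, memory, beta):
    # Read vector = weighted sum of memory rows
    w = content_weights(key, memory, beta)
    return [sum(w[i] * memory[i][d] for i in range(len(memory))) for d in range(len(memory[0]))]

memory = [[1.0, 0.0], [0.0, 1.0], [0.7, 0.7]]
query = [1.0, 0.1]
for raw in (-2.0, 0.0, 5.0):
    beta = softplus(raw)
    weights = content_weights(query, memory, beta)
    print(f"raw={raw:+.1f} beta={beta:.3f} weights={[round(w, 3) for w in weights]}")
```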

class NeuralMemoryBank(nn.Module):
    def __init__(self, config: MemoryConfig):
        super().__init__()
        self.memorysize = config.memorysize
        self.memorydim = config.memorydim
        self.numreadheads = config.numreadheads
        # Small random init: with an all-zero memory every slot has the same cosine
        # similarity to any key, so addressing stays uniform and all slots stay identical
        self.register_buffer('memory', torch.randn(config.memorysize, config.memorydim) * 0.01)
        self.register_buffer('usage', torch.zeros(config.memorysize))

    def contentaddressing(self, key, beta):
        # Cosine similarity between the key and every slot, sharpened by beta, then softmax
        keynorm = F.normalize(key, dim=-1)
        memnorm = F.normalize(self.memory, dim=-1)
        similarity = torch.matmul(keynorm, memnorm.t())
        return F.softmax(beta * similarity, dim=-1)

    def write(self, writekey, writevector, erasevector, writestrength):
        writeweights = self.contentaddressing(writekey, writestrength)
        # Erase, then add, each scaled per slot by the write weights.
        # detach() keeps memory out of the autograd graph, so the write-head
        # layers receive no gradient from the loss.
        erase = torch.outer(writeweights.squeeze(), erasevector.squeeze())
        self.memory = (self.memory * (1 - erase)).detach()
        add = torch.outer(writeweights.squeeze(), writevector.squeeze())
        self.memory = (self.memory + add).detach()
        # Decaying usage statistic (tracked, but not used for addressing here)
        self.usage = (0.99 * self.usage + writeweights.squeeze()).detach()

    def read(self, readkeys, readstrengths):
        # One content-addressed read per head; results are concatenated
        reads = []
        for i in range(self.numreadheads):
            weights = self.contentaddressing(readkeys[i], readstrengths[i])
            readvector = torch.matmul(weights, self.memory)
            reads.append(readvector)
        return torch.cat(reads, dim=-1)

class MemoryController(nn.Module):
    def __init__(self, inputdim, hiddendim, memoryconfig: MemoryConfig):
        super().__init__()
        self.hiddendim = hiddendim
        self.memoryconfig = memoryconfig
        self.lstm = nn.LSTM(inputdim, hiddendim, batch_first=True)
        totalreaddim = memoryconfig.numreadheads * memoryconfig.memorydim
        # Interface layers: a key and strength per read head, plus one write head
        self.readkeys = nn.Linear(hiddendim, memoryconfig.numreadheads * memoryconfig.memorydim)
        self.readstrengths = nn.Linear(hiddendim, memoryconfig.numreadheads)
        self.writekey = nn.Linear(hiddendim, memoryconfig.memorydim)
        self.writevector = nn.Linear(hiddendim, memoryconfig.memorydim)
        self.erasevector = nn.Linear(hiddendim, memoryconfig.memorydim)
        self.writestrength = nn.Linear(hiddendim, 1)
        self.output = nn.Linear(hiddendim + totalreaddim, inputdim)

    def forward(self, x, memorybank, hidden=None):
        # x is a single (inputdim,) vector; unsqueeze(0) makes it an unbatched length-1 sequence
        lstmout, hidden = self.lstm(x.unsqueeze(0), hidden)
        controllerstate = lstmout.squeeze(0)
        readk = self.readkeys(controllerstate).view(self.memoryconfig.numreadheads, -1)
        reads = F.softplus(self.readstrengths(controllerstate))    # read strengths > 0
        writek = self.writekey(controllerstate)
        writev = torch.tanh(self.writevector(controllerstate))     # write values in (-1, 1)
        erasev = torch.sigmoid(self.erasevector(controllerstate))  # erase fractions in (0, 1)
        writes = F.softplus(self.writestrength(controllerstate))   # write strength > 0
        # Read before writing, so the output uses memory as it was before this step
        readvectors = memorybank.read(readk, reads)
        memorybank.write(writek, writev, erasev, writes)
        combined = torch.cat([controllerstate, readvectors], dim=-1)
        output = self.output(combined)
        return output, hidden
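The write step erases and then adds, with both operations scaled by the write weights. A minimal pure-Python sketch of that rule, using made-up values, follows. A slot with weight 1 and a full erase is replaced by the write vector, while slots with weight 0 are untouched. The sketch also shows a limitation: if all slots are identical (for example, an all-zero memory), cosine similarity is the same for every slot. The weights then come out uniform, and every slot receives the same update.

```python
def erase_add_write(memory, weights, erase, add):
    # Per slot i and dimension d: M[i][d] * (1 - w[i] * e[d]) + w[i] * a[d]
    return [
        [m * (1.0 - w * e) + w * a for m, e, a in zip(row, erase, add)]
        for row, w in zip(memory, weights)
    ]

memory = [[0.5, 0.5], [0.2, -0.3], [0.9, 0.1]]
# Hypothetical write focused entirely on slot 1 with a full erase
new_memory = erase_add_write(memory, [0.0, 1.0, 0.0], erase=[1.0, 1.0], add=[0.8, -0.4])
print(new_memory)

# Uniform weights stand in for what content addressing returns when all slots are identical
zero_memory = [[0.0, 0.0] for _ in range(3)]
uniform = [1.0 / 3] * 3
after = erase_add_write(zero_memory, uniform, erase=[0.5, 0.5], add=[0.3, 0.6])
print(after)  # every row receives the same update
```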

Enhancing Learning with Experience Replay and Meta-Learning

To bolster the agent’s continual learning capabilities, we incorporate two critical components:

  • Experience Replay: This mechanism stores past samples in a prioritized buffer so the model can revisit earlier knowledge. Each sample's priority is its training loss at the time it was stored, so harder samples are replayed more often. Importance-sampling weights partially correct for the resulting sampling bias.
  • Meta-Learner: Inspired by Model-Agnostic Meta-Learning (MAML), this module adapts the model to a new task by taking a few gradient steps on support data. It is included as a component, but the demo below does not call it.

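Prioritized sampling and its importance-sampling correction can be sketched without PyTorch or NumPy. In this hypothetical helper, priorities are raised to alpha at sampling time. The ExperienceReplay class applies the exponent when pushing instead, which yields the same probabilities. Indices are drawn by successive weighted draws without replacement. Each sample then gets the weight (N·P(i))^(-beta), normalized by the largest weight in the batch.

```python
import random

def prioritized_sample(priorities, batch_size, alpha=0.6, beta=0.4, rng=None):
    if not priorities:
        return [], []
    rng = rng or random.Random()
    # P(i) = p_i^alpha / sum_k p_k^alpha
    scaled = [p ** alpha for p in priorities]
    total = sum(scaled)
    probs = [s / total for s in scaled]
    # Successive weighted draws without replacement
    remaining = list(range(len(probs)))
    chosen = []
    for _ in range(min(batch_size, len(probs))):
        i = rng.choices(remaining, weights=[probs[j] for j in remaining], k=1)[0]
        remaining.remove(i)
        chosen.append(i)
    # Importance-sampling weights (N * P(i))^(-beta), scaled so the largest is 1
    n = len(probs)
    raw = [(n * probs[i]) ** (-beta) for i in chosen]
    top = max(raw)
    return chosen, [w / top for w in raw]

rng = random.Random(0)
priorities = [0.1, 0.1, 5.0, 0.1, 2.0]  # e.g. per-sample losses at insertion time
counts = [0] * len(priorities)
for _ in range(2000):
    picked, _ = prioritized_sample(priorities, 1, rng=rng)
    counts[picked[0]] += 1
print(counts)
```

High-priority samples are drawn more often, and the weights scale their contribution down so the update is not dominated by them.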
class ExperienceReplay:
    def __init__(self, capacity=10000, alpha=0.6):
        self.capacity = capacity
        self.alpha = alpha
        self.buffer = deque(maxlen=capacity)
        self.priorities = deque(maxlen=capacity)  # stays aligned with buffer

    def push(self, experience, priority=1.0):
        self.buffer.append(experience)
        # Store p^alpha: alpha=0 gives uniform sampling, alpha=1 fully prioritized
        self.priorities.append(priority ** self.alpha)

    def sample(self, batchsize, beta=0.4):
        if len(self.buffer) == 0:
            return [], []
        probs = np.array(self.priorities)
        probs = probs / probs.sum()
        indices = np.random.choice(len(self.buffer), min(batchsize, len(self.buffer)), p=probs, replace=False)
        samples = [self.buffer[i] for i in indices]
        # Importance-sampling weights (N * P(i))^-beta, normalized so the largest is 1
        weights = (len(self.buffer) * probs[indices]) ** (-beta)
        weights = weights / weights.max()
        return samples, torch.as_tensor(weights, dtype=torch.float32)


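class MetaLearner(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def adapt(self, supportx, supporty, memorybank, numsteps=5, lr=0.01):
        # MAML-style inner loop on a copy of the parameters; the model's own weights are untouched
        adaptedparams = {name: param.clone() for name, param in self.model.named_parameters()}
        for _ in range(numsteps):
            # Run the model with the adapted parameters (torch.func requires PyTorch 2.0+)
            pred, _ = torch.func.functional_call(self.model, adaptedparams, (supportx, memorybank))
            loss = F.mse_loss(pred, supporty)
            # allow_unused=True: the write-head layers never reach the loss (memory is detached)
            grads = torch.autograd.grad(loss, list(adaptedparams.values()), create_graph=True, allow_unused=True)
            adaptedparams = {
                name: param if grad is None else param - lr * grad
                for (name, param), grad in zip(adaptedparams.items(), grads)
            }
        return adaptedparams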
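The inner loop of adapt can be illustrated on a one-dimensional linear model with hand-written gradients. This is a sketch under simplifying assumptions: a toy model and task, plain SGD steps, and no outer loop. Full MAML would also differentiate a held-out query loss through these steps with respect to the original parameters. Setting create_graph=True is what makes those higher-order gradients possible in the PyTorch version.

```python
def mse_and_grads(params, xs, ys):
    # Mean squared error and its analytic gradients for y_hat = w * x + b
    w, b = params["w"], params["b"]
    errs = [w * x + b - y for x, y in zip(xs, ys)]
    n = len(xs)
    loss = sum(e * e for e in errs) / n
    grads = {
        "w": 2.0 * sum(e * x for e, x in zip(errs, xs)) / n,
        "b": 2.0 * sum(errs) / n,
    }
    return loss, grads

def adapt_linear(meta_params, support_x, support_y, num_steps=5, lr=0.1):
    # Inner loop: a few SGD steps on a copy; meta_params itself is not modified
    adapted = dict(meta_params)
    for _ in range(num_steps):
        _, grads = mse_and_grads(adapted, support_x, support_y)
        adapted = {name: value - lr * grads[name] for name, value in adapted.items()}
    return adapted

meta_params = {"w": 0.0, "b": 0.0}
support_x = [-1.0, 0.0, 1.0, 2.0]
support_y = [2.0 * x + 1.0 for x in support_x]  # hypothetical new task: y = 2x + 1
before, _ = mse_and_grads(meta_params, support_x, support_y)
adapted = adapt_linear(meta_params, support_x, support_y)
after, _ = mse_and_grads(adapted, support_x, support_y)
print(f"support loss: before={before:.3f}, after={after:.3f}")
```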

Integrating Components into a Continual Learning Agent

The Continual Learning Agent combines the memory bank, controller, replay buffer and meta-learner into one system. Its training step fits the model to the current sample and replays past samples to reduce forgetting. Its evaluation routine measures the error on previously learned tasks. The meta-learner is constructed here but is not called during training.

class ContinualLearningAgent:
    def __init__(self, inputdim=64, hiddendim=128):
        self.config = MemoryConfig()
        self.memorybank = NeuralMemoryBank(self.config)
        self.controller = MemoryController(inputdim, hiddendim, self.config)
        self.replaybuffer = ExperienceReplay(capacity=5000)
        self.metalearner = MetaLearner(self.controller)
        self.optimizer = torch.optim.Adam(self.controller.parameters(), lr=0.001)
        self.taskhistory = []

    def trainstep(self, x, y, usereplay=True):
        self.optimizer.zero_grad()
        pred, _ = self.controller(x, self.memorybank)
        currentloss = F.mse_loss(pred, y)
        # Priority = current loss, plus a small constant so it is never zero
        self.replaybuffer.push((x.detach().clone(), y.detach().clone()), priority=currentloss.item() + 1e-6)
        totalloss = currentloss

        if usereplay and len(self.replaybuffer.buffer) > 16:
            samples, weights = self.replaybuffer.sample(8)
            for (replayx, replayy), weight in zip(samples, weights):
                with torch.enable_grad():
                    replaypred, _ = self.controller(replayx, self.memorybank)
                    replayloss = F.mse_loss(replaypred, replayy)
                    # Replay losses are scaled by 0.3 and by their importance-sampling weight
                    totalloss = totalloss + 0.3 * replayloss * weight

        totalloss.backward()
        torch.nn.utils.clip_grad_norm_(self.controller.parameters(), 1.0)
        self.optimizer.step()
        return totalloss.item()

    def evaluate(self, testdata):
        self.controller.eval()
        totalerror = 0
        with torch.no_grad():
            for x, y in testdata:
                # Note: the forward pass still writes to the memory bank during evaluation
                pred, _ = self.controller(x, self.memorybank)
                totalerror += F.mse_loss(pred, y).item()
        self.controller.train()
        return totalerror / len(testdata)
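Written out, one trainstep call minimizes the objective below. The symbols are introduced here for illustration. f is the controller, ℒ is the mean squared error, and ℬ is the replayed batch of up to 8 samples. The batch is used only when replay is enabled and the buffer holds more than 16 entries. The w_j are the normalized importance-sampling weights, and p is the priority stored for the new sample.

```latex
\[
\mathcal{L}_{\text{total}} = \mathcal{L}\big(f(\mathbf{x}), \mathbf{y}\big) + 0.3 \sum_{j \in \mathcal{B}} w_j \, \mathcal{L}\big(f(\mathbf{x}_j), \mathbf{y}_j\big),
\qquad
p = \mathcal{L}\big(f(\mathbf{x}), \mathbf{y}\big) + 10^{-6}
\]
```

The replay term is a weighted sum over the batch rather than a mean. The gradient of the total is clipped to norm 1.0 before the Adam step.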

Simulating Sequential Tasks for Continual Learning

To test the agent's ability to learn multiple tasks in sequence, we generate synthetic datasets with different input-output mappings. Task 1 (taskid 0) maps each 64-dimensional input to the sine of its mean. Task 2 maps it to half the cosine of its mean, and both of these targets are broadcast across all 64 outputs. Each later task applies tanh(0.5·x + taskid) element-wise, so tasks 3 and 4 differ only in their offset.

def createtaskdata(taskid, numsamples=100):
    # Seeding with taskid alone means every call for a task draws from the same
    # random stream, so a smaller sample can repeat the first rows of a larger one
    torch.manual_seed(taskid)
    x = torch.randn(numsamples, 64)
    if taskid == 0:
        y = torch.sin(x.mean(dim=1, keepdim=True).expand(-1, 64))
    elif taskid == 1:
        y = torch.cos(x.mean(dim=1, keepdim=True).expand(-1, 64)) * 0.5
    else:
        y = torch.tanh(x * 0.5 + taskid)
    return [(x[i], y[i]) for i in range(numsamples)]
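Because createtaskdata seeds only with taskid, the 20-sample evaluation call can repeat the first 20 samples of the 50-sample training call. In that case the reported test error partly measures training fit. The sketch below shows one way to keep the splits apart, written with the standard library for brevity. The hypothetical create_task_split derives its seed from both the task and a split index, and it uses 8-dimensional inputs instead of 64. The task functions follow the ones above.

```python
import math
import random

def create_task_split(task_id, num_samples, split_seed, dim=8):
    # Hypothetical variant of createtaskdata: the seed depends on task AND split
    rng = random.Random(1000 * task_id + split_seed)
    data = []
    for _ in range(num_samples):
        x = [rng.gauss(0.0, 1.0) for _ in range(dim)]
        mean = sum(x) / dim
        if task_id == 0:
            y = [math.sin(mean)] * dim
        elif task_id == 1:
            y = [0.5 * math.cos(mean)] * dim
        else:
            y = [math.tanh(0.5 * v + task_id) for v in x]
        data.append((x, y))
    return data

# Same seed for both calls: the short draw is a prefix of the long one
same_seed_train = create_task_split(0, 50, split_seed=0)
same_seed_test = create_task_split(0, 20, split_seed=0)
print(same_seed_test == same_seed_train[:20])

# Different split seeds: no shared inputs
train = create_task_split(0, 50, split_seed=0)
test = create_task_split(0, 20, split_seed=1)
train_inputs = [x for x, _ in train]
print(any(x in train_inputs for x, _ in test))
```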

Running the Continual Learning Experiment

The following routine trains the agent on a series of tasks. After each task, it evaluates the agent on every task learned so far, and comparing these errors shows how well earlier tasks are retained while new ones are learned. The figure shows the final contents of the memory bank and the test error on each task measured right after that task was trained.

def runcontinuallearningdemo():
    print("🧠 Neural Memory Agent - Lifelong Learning Demonstration")
    print("=" * 60)
    agent = ContinualLearningAgent()
    numtasks = 4
    results = {'tasks': [], 'withmemory': []}

    for taskid in range(numtasks):
        print(f"\n📚 Training Task {taskid + 1} of {numtasks}")
        traindata = createtaskdata(taskid, numsamples=50)
        testdata = createtaskdata(taskid, numsamples=20)

        for epoch in range(20):
            totalloss = 0
            for x, y in traindata:
                # Replay is enabled from the second task onward
                loss = agent.trainstep(x, y, usereplay=(taskid > 0))
                totalloss += loss
            if epoch % 5 == 0:
                avgloss = totalloss / len(traindata)
                print(f"  Epoch {epoch:2d}: Loss = {avgloss:.4f}")

        print("\n📊 Evaluating on all learned tasks:")
        for evaltaskid in range(taskid + 1):
            evaldata = createtaskdata(evaltaskid, numsamples=20)
            error = agent.evaluate(evaldata)
            print(f"    Task {evaltaskid + 1}: Test Error = {error:.4f}")
            # Only the error on the task just trained is recorded for the plot
            if evaltaskid == taskid:
                results['tasks'].append(evaltaskid + 1)
                results['withmemory'].append(error)

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    ax = axes[0]
    memorymatrix = agent.memorybank.memory.detach().numpy()
    im = ax.imshow(memorymatrix, aspect='auto', cmap='viridis')
    ax.set_title('Neural Memory Bank Activation', fontsize=14, fontweight='bold')
    ax.set_xlabel('Memory Vector Dimension')
    ax.set_ylabel('Memory Slots')
    plt.colorbar(im, ax=ax)

    ax = axes[1]
    ax.plot(results['tasks'], results['withmemory'], marker='o', linewidth=2, markersize=8, label='With Memory Replay')
    ax.set_title('Performance Across Tasks', fontsize=14, fontweight='bold')
    ax.set_xlabel('Task Number')
    ax.set_ylabel('Mean Squared Error')
    ax.legend()
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig('neuralmemoryresults.png', dpi=150, bbox_inches='tight')
    print("\n✅ Results saved as 'neuralmemoryresults.png'")
    plt.show()

    print("\n" + "=" * 60)
    print("🎯 Key Takeaways:")
    print("  • The memory bank holds a compact, content-addressable store written by the controller.")
    print("  • Prioritized replay revisits high-loss past samples; compare with usereplay=False to measure its effect.")
    print("  • The per-task test errors above show how much earlier tasks degrade as new ones are learned.")
    print("  • Content-based addressing retrieves memory by similarity rather than by position.")

if __name__ == "__main__":
    runcontinuallearningdemo()
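The demo prints the error on every earlier task after each new task. However, results keeps only the error on the task just trained, so the plot does not show forgetting directly. One way to summarize retention is to keep the full grid of errors and compare each task's final error with its error right after it was learned. The sketch below does that. The function name and the numbers are illustrative only and are not output of the demo. Running the same loop with usereplay=False would give a baseline to compare against.

```python
def forgetting_report(errors):
    # errors[i][j]: test error on task j measured right after training task i (j <= i)
    final = errors[-1]
    report = {}
    for j in range(len(errors) - 1):
        # Positive value: error on task j+1 grew after later tasks were trained
        report[j + 1] = final[j] - errors[j][j]
    return report

# Illustrative numbers only, not results from the demo
errors = [
    [0.10],
    [0.15, 0.12],
    [0.20, 0.14, 0.09],
]
report = forgetting_report(errors)
average_increase = sum(report.values()) / len(report)
print(report, round(average_increase, 3))
```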

Summary

In this tutorial, we built a neural memory agent for lifelong learning that integrates differentiable memory, experience replay and meta-learning. The design aims to let the agent store and recall task representations, adapt quickly to new tasks, and limit forgetting on earlier ones. The demo's per-task errors show how far it gets on the synthetic tasks. Memory-augmented systems of this kind are a promising direction for models that keep learning without succumbing to catastrophic forgetting.
