Skip to content

STDP Training ​

This example demonstrates a complete plasticity training loop: a baseline phase without learning, an STDP training phase with periodic synaptic scaling, and a frozen-weight evaluation phase. It produces plots of weight evolution, weight distributions, spike rasters, and per-synapse weight changes.

See Synaptic Plasticity for background on how the learning rule works.

py
# /// script
# requires-python = ">=3.11"
# dependencies = [
#   "livn",
#   "matplotlib",
# ]
# ///

import os

os.environ["LIVN_BACKEND"] = "neuron"

import matplotlib.pyplot as plt
import numpy as np

from livn.env import Env


BASELINE_MS = 200  # no learning warmup
TRAINING_MS = 1000
EVAL_MS = 200  # frozen weights
WEIGHT_REC_DT = 1.0


env = Env("systems/graphs/EI1").init()
env.apply_model_defaults()
env.record_spikes()
env.record_voltage()
env.record_weights(dt=WEIGHT_REC_DT)

print(f"Phase 1: Baseline ({BASELINE_MS} ms, no plasticity)")
it_base, t_base, iv_base, v_base, *_ = env.run(BASELINE_MS)

weights_before = env.get_weights()
print(f" {len(weights_before)} STDP-capable synapses")

print(f"Phase 2: Training ({TRAINING_MS} ms, STDP + synaptic scaling)")
env.enable_plasticity()

it_train_all, t_train_all = [], []
iv_train_all, v_train_all = [], []
remaining = TRAINING_MS
while remaining > 0:
    chunk = min(100, remaining)
    it_chunk, t_chunk, iv_chunk, v_chunk, *_ = env.run(chunk)
    it_train_all.append(np.asarray(it_chunk))
    t_train_all.append(np.asarray(t_chunk))
    env.normalize_weights()
    remaining -= chunk

it_train = np.concatenate(it_train_all) if it_train_all else np.array([])
t_train = np.concatenate(t_train_all) if t_train_all else np.array([])

weights_after_training = env.get_weights()

print(f"Phase 3: Evaluation ({EVAL_MS} ms, plasticity frozen)")
env.disable_plasticity()

it_eval, t_eval, iv_eval, v_eval, *_ = env.run(EVAL_MS)

weights_final = env.get_weights()

w_before = np.array(list(weights_before.values()))
w_after = np.array(list(weights_after_training.values()))
w_final = np.array(list(weights_final.values()))

if env.w_recs:
    max_dev = 0.0
    for key, vec in env.w_recs.items():
        trace = np.array(vec.as_numpy())
        dev = np.max(np.abs(trace - 1.0))
        max_dev = max(max_dev, dev)
    print(f" Peak weight deviation during training: {max_dev:.4f}")

frozen = np.sum(np.abs(w_final - w_after) > 1e-9)
print(f" Synapses drifted during eval (should be 0): {frozen}")


fig, axes = plt.subplots(2, 2, figsize=(14, 9))

# Weight evolution over training
ax = axes[0, 0]
n_traces = min(20, len(env.w_recs))
for i, (key, vec) in enumerate(env.w_recs.items()):
    if i >= n_traces:
        break
    trace = np.array(vec.as_numpy())
    time_axis = np.arange(len(trace)) * WEIGHT_REC_DT
    ax.plot(time_axis, trace, alpha=0.6, linewidth=0.8)

ax.set_xlabel("Time (ms)")
ax.set_ylabel("Weight multiplier (w)")
ax.set_title(f"Weight evolution (first {n_traces} synapses, scaled every 100 ms)")
ax.axhline(y=1.0, color="k", linestyle="--", alpha=0.3, label="initial w=1")
ax.axvline(x=BASELINE_MS, color="red", linestyle="-", alpha=0.3, label="STDP on")
ax.axvline(
    x=BASELINE_MS + TRAINING_MS,
    color="orange",
    linestyle="-",
    alpha=0.3,
    label="STDP off",
)
ax.legend(fontsize=8)

# Weight distributions before/after
ax = axes[0, 1]
bins = np.linspace(
    min(w_before.min(), w_after.min()) - 0.05,
    max(w_before.max(), w_after.max()) + 0.05,
    40,
)
ax.hist(w_before, bins=bins, alpha=0.6, label="Before training", color="steelblue")
ax.hist(w_after, bins=bins, alpha=0.6, label="After training", color="coral")
ax.axvline(x=1.0, color="k", linestyle="--", alpha=0.3)
ax.set_xlabel("Weight multiplier (w)")
ax.set_ylabel("Count")
ax.set_title("Weight distribution")
ax.legend(fontsize=8)

ax = axes[1, 0]
total_offset = 0

for phase_label, it_data, t_data, t_offset, color in [
    ("Baseline", it_base, t_base, 0, "gray"),
    ("Training", it_train, t_train, BASELINE_MS, "tab:blue"),
    ("Eval", it_eval, t_eval, BASELINE_MS + TRAINING_MS, "tab:green"),
]:
    if len(it_data) > 0 and len(t_data) > 0:
        ax.scatter(
            np.array(t_data) + t_offset,
            np.array(it_data),
            s=1,
            alpha=0.5,
            color=color,
            label=phase_label,
        )

ax.axvline(x=BASELINE_MS, color="red", linestyle="-", alpha=0.5, label="STDP on")
ax.axvline(
    x=BASELINE_MS + TRAINING_MS,
    color="orange",
    linestyle="-",
    alpha=0.5,
    label="STDP off",
)
ax.set_xlabel("Time (ms)")
ax.set_ylabel("Neuron GID")
ax.set_title("Spike raster")
ax.legend(fontsize=7, markerscale=5)

ax = axes[1, 1]
delta_w = w_after - w_before
sort_idx = np.argsort(delta_w)
colors = ["coral" if d > 0 else "steelblue" for d in delta_w[sort_idx]]
ax.bar(range(len(delta_w)), delta_w[sort_idx], color=colors, width=1.0)
ax.axhline(y=0, color="k", linewidth=0.5)
ax.set_xlabel("Synapse (sorted)")
ax.set_ylabel("delta w")
ax.set_title("Per-synapse weight change")

plt.tight_layout()
plt.savefig("stdp_training.png", dpi=150)
print("\nPlot saved to stdp_training.png")
plt.show()

env.close()

Released under the MIT License.