Spyx

Spyx is a JAX-based framework for spiking neural networks (SNNs) and deep learning that enables fully JIT-compiled optimization of models.

import spyx
import spyx.nn as snn

import jax
import jax.numpy as jnp

import nir

Import a NIR graph into Spyx:

# Load the NIR graph from disk.
nir_graph = nir.read("saved_network.nir")

# Build a Spyx network and its parameters from the NIR graph.
# `sample_batch` is a placeholder for a representative batch of your input
# data; it is only needed for shape information.
# `dt` is used to scale the weights properly if the imported network was
# trained in a different simulator whose timestep is not necessarily 1.
SNN, params = spyx.nir.from_nir(nir_graph, sample_batch, dt=1)

# Use the imported network as you wish, e.g. apply it to your data.
# NOTE: this example discards the return value; assign it to inspect outputs.
SNN.apply(params, sample_batch)

Export a network from Spyx to a NIR graph:

# Some operations may have rearranged the PyTree (dictionary) that stores
# the SNN weights, so this helper reorders the dict to allow for a proper
# export. `init_params` and `optimized_params` are placeholders for your
# model's parameters before and after training (presumably the initial
# params provide the reference layer order -- confirm in the Spyx docs).
export_params = spyx.nir.reorder_layers(init_params, optimized_params)

# Provide the params to export along with the input/output shapes and the
# desired time resolution `dt`. This lets other frameworks that support
# smaller time intervals load the graph with the proper dt, whereas Spyx
# assumes every timestep is 1 to avoid units.
# `input_shape`, `output_shape`, and `dt` are placeholders you must define.
nir_graph = spyx.nir.to_nir(export_params, input_shape, output_shape, dt)

# Write the NIR graph to the desired file path.
nir.write("./spyx_shd.nir", nir_graph)