NDScala

N-dimensional (multi-dimensional) arrays, i.e. tensors, in Scala 3. Think NumPy ndarray or PyTorch Tensor, but type-safe over shapes, array/axis labels and numeric data types.
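To illustrate what "type-safe over shapes" can mean in Scala 3, here is a minimal sketch, assuming nothing about NDScala's internals. NDScala writes shapes as `10000 #: 10000 #: SNil`; this sketch instead uses a hypothetical two-parameter `Mat` class and tuples of Int literal types, and the names `Mat` and `dimsOf` are made up for illustration. Dimensions live in the types, so a mismatched `matmul` fails to compile, and the literal types can be turned back into runtime `Int`s.

```scala
import scala.compiletime.constValueTuple

// Hypothetical sketch, NOT NDScala's API: a row-major Float matrix whose row
// and column counts are Int literal types (e.g. Mat[2, 3]).
final class Mat[R <: Int, C <: Int](val data: Array[Float])(using r: ValueOf[R], c: ValueOf[C]):
  val rows: Int = r.value // runtime value recovered from the literal type R
  val cols: Int = c.value
  require(data.length == rows * cols, s"expected ${rows * cols} elements, got ${data.length}")

  // (R x C) times (C x K) gives (R x K); the shared C is checked by the compiler
  def matmul[K <: Int](that: Mat[C, K])(using kv: ValueOf[K]): Mat[R, K] =
    val k: Int = kv.value
    val out = new Array[Float](rows * k)
    for i <- 0 until rows; j <- 0 until k do
      var s = 0f
      for p <- 0 until cols do s += data(i * cols + p) * that.data(p * k + j)
      out(i * k + j) = s
    new Mat[R, K](out)(using r, kv)

// Reify a whole shape, written as a tuple of literal types, into ordinary Ints
inline def dimsOf[S <: Tuple]: List[Int] =
  constValueTuple[S].productIterator.map(_.asInstanceOf[Int]).toList
```

With `a: Mat[2, 3]` and `b: Mat[3, 1]`, `a.matmul(b)` compiles to a `Mat[2, 1]`, while `b.matmul(a)` is rejected at compile time because `b` has 1 column and `a` has 2 rows. `dimsOf[(10000, 10000)]` yields `List(10000, 10000)`, whose product, 100000000, is the element count of the 10000 × 10000 `ones` tensor in the NDScala example.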


Training a (shape-safe) neural network in 10 lines:

In NDScala:

// After some setup
// Declaring types and their corresponding values
type Mat10kX10k = 10000 #: 10000 #: SNil
type AxisLabels = "AxisLabel" ##: "AxisLabel" ##: TSNil
val mat10kX10k = shapeOf[Mat10kX10k]
val axisLabels = tensorShapeDenotationOf[AxisLabels]

// A 10000 x 10000 tensor of 1.0f, stored as a flat array of 100000000 elements
val ones = Tensor(Array.fill(10000 * 10000)(1.0f), "TensorLabel", axisLabels, mat10kX10k)

def train(x: Tensor[Float, ("TensorLabel", AxisLabels, Mat10kX10k)],
          y: Tensor[Float, ("TensorLabel", AxisLabels, Mat10kX10k)],
          w0: Tensor[Float, ("TensorLabel", AxisLabels, Mat10kX10k)],
          w1: Tensor[Float, ("TensorLabel", AxisLabels, Mat10kX10k)],
          iter: Int): Tuple2[Tensor[Float, ("TensorLabel", AxisLabels, Mat10kX10k)],
                             Tensor[Float, ("TensorLabel", AxisLabels, Mat10kX10k)]] =
    if iter == 0 then (w0, w1)
    else
        val l1 = x.matmul(w0).sigmoid()
        val l2 = l1.matmul(w1).sigmoid()
        val error = y - l2
        // error times the sigmoid derivative l2 * (1 - l2)
        val l2Delta = error * (l2 * (ones - l2))
        // backpropagate through w1, then through the hidden layer's sigmoid
        val l1Delta = l2Delta.matmul(w1.transpose) * (l1 * (ones - l1))
        val w1New = w1 + l1.transpose.matmul(l2Delta)
        val w0New = w0 + x.transpose.matmul(l1Delta)
        train(x, y, w0New, w1New, iter - 1)
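For readers without NDScala at hand, the same forward pass and update rule can be sketched in plain Scala on small matrices. This is an illustrative re-implementation, not NDScala code: matrices are row-major `Array[Array[Double]]` (Double rather than Float), there is no shape checking, and the helper names `matmul`, `elementwise`, `sigmoid` and `step` are hypothetical.

```scala
// Plain-Scala sketch of the two-layer update (not NDScala).
// Matrices are row-major Array[Array[Double]].
type M = Array[Array[Double]]

def matmul(a: M, b: M): M =
  Array.tabulate(a.length, b(0).length) { (i, j) =>
    b.indices.map(p => a(i)(p) * b(p)(j)).sum
  }

// Apply f to matching elements of two equally shaped matrices
def elementwise(a: M, b: M)(f: (Double, Double) => Double): M =
  Array.tabulate(a.length, a(0).length)((i, j) => f(a(i)(j), b(i)(j)))

def sigmoid(a: M): M = a.map(_.map(z => 1.0 / (1.0 + math.exp(-z))))

// One full-batch step with learning rate 1, mirroring l1, l2, l2Delta, l1Delta
def step(x: M, y: M, w0: M, w1: M): (M, M) =
  val l1 = sigmoid(matmul(x, w0))
  val l2 = sigmoid(matmul(l1, w1))
  val error = elementwise(y, l2)(_ - _)
  // error times the sigmoid derivative of the output layer
  val l2Delta = elementwise(error, l2)((e, o) => e * o * (1 - o))
  // backpropagated error times the sigmoid derivative of the hidden layer
  val l1Delta = elementwise(matmul(l2Delta, w1.transpose), l1)((g, h) => g * h * (1 - h))
  val w1New = elementwise(w1, matmul(l1.transpose, l2Delta))(_ + _)
  val w0New = elementwise(w0, matmul(x.transpose, l1Delta))(_ + _)
  (w0New, w1New)

@annotation.tailrec
def train(x: M, y: M, w0: M, w1: M, iter: Int): (M, M) =
  if iter == 0 then (w0, w1)
  else
    val (w0New, w1New) = step(x, y, w0, w1)
    train(x, y, w0New, w1New, iter - 1)
```

As a hand check: with one input `x = [[1.0]]`, target `y = [[1.0]]` and zero weights, one iteration gives `l1 = l2 = 0.5`, `l2Delta = 0.125`, so `w1` becomes `0.0625` while `w0` stays at `0.0` (its delta is multiplied by the zero `w1`). If the targets already equal the network's output, every delta is zero and the weights do not move.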

And for reference, in NumPy, in 10 lines:

import numpy as np

def train(X, Y, iter):
    # Random float32 weights in [-1, 1), same 10000 x 10000 shape as in NDScala
    syn0 = 2*np.random.random((10000,10000)).astype('float32') - 1
    syn1 = 2*np.random.random((10000,10000)).astype('float32') - 1
    for j in range(iter):
        l1 = 1/(1+np.exp(-(np.dot(X,syn0))))   # sigmoid(X . syn0)
        l2 = 1/(1+np.exp(-(np.dot(l1,syn1))))  # sigmoid(l1 . syn1)
        error = Y - l2
        l2_delta = error*(l2*(1-l2))
        l1_delta = l2_delta.dot(syn1.T) * (l1 * (1-l1))
        syn1 += l1.T.dot(l2_delta)
        syn0 += X.T.dot(l1_delta)
    return syn0, syn1
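Both listings implement the standard backpropagation updates for a two-layer sigmoid network with learning rate 1. Writing $\sigma(z) = 1/(1+e^{-z})$, whose derivative is $\sigma(z)(1-\sigma(z))$, and assuming the loss $E = \tfrac{1}{2}\lVert Y - l_2 \rVert^2$ (neither listing states a loss explicitly), the updates are:

```latex
\begin{aligned}
l_1 &= \sigma(X W_0), \qquad l_2 = \sigma(l_1 W_1), \\
\delta_2 &= (Y - l_2) \odot l_2 \odot (1 - l_2), \\
\delta_1 &= (\delta_2 W_1^{\top}) \odot l_1 \odot (1 - l_1), \\
W_1 &\leftarrow W_1 + l_1^{\top} \delta_2 = W_1 - \frac{\partial E}{\partial W_1}, \\
W_0 &\leftarrow W_0 + X^{\top} \delta_1 = W_0 - \frac{\partial E}{\partial W_0}.
\end{aligned}
```

Here $\delta_2$ and $\delta_1$ correspond to `l2_delta`/`l2Delta` and `l1_delta`/`l1Delta`, $W_0, W_1$ to `syn0`/`w0` and `syn1`/`w1`, and the factor $l_1 \odot (1 - l_1)$ in $\delta_1$ is the derivative of the hidden layer's sigmoid.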

The run time of the NDScala version is ~80% of that of NumPy with MKL.

The PyTorch equivalent is slightly faster, running in ~85% of the NDScala version's time. The difference can be accounted for by the overhead of copying data between the JVM and native memory.
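To give a sense of the copy cost mentioned above, here is a generic JVM sketch, not NDScala's actual data path, that copies a heap `Array[Float]` into an off-heap direct buffer, the kind of memory a native library can read. The helper names `toNative` and `timeCopyMillis` are hypothetical. For one 10000 × 10000 Float matrix, each such copy moves 100,000,000 floats, i.e. about 400 MB.

```scala
import java.nio.{ByteBuffer, ByteOrder, FloatBuffer}

// Copy a JVM-heap Float array into off-heap (native) memory.
// Illustrative only; not how NDScala hands tensors to its backend.
def toNative(data: Array[Float]): FloatBuffer =
  val buf = ByteBuffer
    .allocateDirect(data.length * java.lang.Float.BYTES) // off-heap allocation
    .order(ByteOrder.nativeOrder())                      // byte order native code expects
    .asFloatBuffer()
  buf.put(data) // O(n) copy out of the Java heap
  buf.flip()    // reset position so a reader starts at index 0
  buf

// Wall-clock time of one heap-to-native copy of n floats, in milliseconds
def timeCopyMillis(n: Int): Double =
  val data = Array.fill(n)(1.0f)
  val t0 = System.nanoTime()
  val native = toNative(data)
  val t1 = System.nanoTime()
  assert(native.remaining() == n)
  (t1 - t0) / 1e6
```

A single measurement like this is noisy (JIT warm-up, allocation); a harness such as JMH is the usual choice for real benchmarks.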

Open Source Agenda is not affiliated with "NDScala" Project. README Source: SciScala/NDScala
