Sunstreaker

A Keras-like framework with JAX as the backend


Sunstreaker

The source code is clear and easy to follow, and the framework is simple to use.

Goals

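  • Clear, concise source code that is good for studying and experimenting with algorithms
  • Quickly try out new ideas for improvements
  • Quickly reproduce new papers
  • Quickly train a large model with distributed training
  • Quickly use open-source model weights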

Notes

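  • This project moves forward in small, frequent steps. Stars are welcome, but forking is not recommended because the code changes quickly.
  • This project is for learning and experimentation only. Do not use it in production.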

Follow the WeChat official account: 无数据不智能 ("No Data, No Intelligence")


Installation

TensorFlow is only needed to load the demo data, so it is optional.

Windows

  1. Install JAX
    • CPU
    pip install jax[cpu]==0.3.14 -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver
    
    • GPU
    pip install jax[cuda111]==0.3.14 -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver
    
  2. Install Graphviz, then build pygraphviz against it (PowerShell; adjust the paths to your Graphviz install):
     pip install --global-option=build_ext `
                   --global-option="-IC:\Program Files\Graphviz\include" `
                   --global-option="-LC:\Program Files\Graphviz\lib" `
                   pygraphviz
    
  3. pip install -r requirements.txt
  4. pip install sunstreaker

Linux

  1. Install JAX

    • CPU
    pip install --upgrade jax[cpu]==0.3.14
    
    • GPU
    pip install --upgrade jax[cuda]==0.3.14 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
    
  2. Install Graphviz

  3. pip install -r requirements.txt

  4. pip install sunstreaker
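After installing, you can run a quick check like the one below to confirm that JAX imports correctly and to see which backend and devices it found. It uses plain JAX and nothing specific to sunstreaker. A CPU-only install should report only CPU devices.

```python
import jax
import jax.numpy as jnp

# Report the installed JAX version and the backend JAX picked ("cpu", "gpu", ...).
print("jax version:", jax.__version__)
print("default backend:", jax.default_backend())

# List the devices JAX can use; a CUDA build should show GPU devices here.
devices = jax.devices()
print("devices:", devices)

# Run a tiny computation to make sure the install works end to end.
x = jnp.arange(4.0)
total = float(jnp.sum(x * x))
print("sum of squares:", total)  # 0 + 1 + 4 + 9 = 14.0
```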

Usage

Load some data with tensorflow_datasets

import tensorflow as tf

# Hide GPUs from TensorFlow so it only loads data and leaves GPU memory to JAX.
tf.config.set_visible_devices([], 'GPU')
import math
import asyncio
import tensorflow_datasets as tfds
import jax.numpy as jnp
from sunstreaker.data import Dataloader
# Dense and Input are used by the model examples below (assumed to live in sunstreaker.layers, like Flatten).
from sunstreaker.layers import Dense, Flatten, Input
from sunstreaker.layers.activations import Softmax
from sunstreaker.losses import categorical_crossentropy
from sunstreaker.metrics import categorical_accuracy
from sunstreaker.optimizers import RMSProp


def load(batch_size: int, func):
   async def tfds_load_data() -> Dataloader:
      # Load the MNIST train/test splits, already batched, as (image, label) pairs.
      ds, info = tfds.load(name="mnist", split=["train", "test"], as_supervised=True, with_info=True,
                           shuffle_files=True, batch_size=batch_size)
      train_ds, valid_ds = ds
      # Apply the same preprocessing to each split.
      train_ds, valid_ds = func(train_ds), func(valid_ds)
      train_ds, valid_ds = train_ds.cache().repeat(), valid_ds.cache().repeat()
      input_shape = tuple(info.features["image"].shape)  # (28, 28, 1)
      num_train_batches = math.ceil(info.splits["train"].num_examples / batch_size)
      num_val_batches = math.ceil(info.splits["test"].num_examples / batch_size)
      return Dataloader(
         train_data=iter(tfds.as_numpy(train_ds)), val_data=iter(tfds.as_numpy(valid_ds)),
         input_shape=input_shape, batch_size=batch_size,
         num_train_batches=num_train_batches, num_val_batches=num_val_batches
      )

   # Run the coroutine here and return a ready-to-use Dataloader.
   return asyncio.run(tfds_load_data())


def load_dataset(batch_size: int):
   # Scale pixels to [0, 1] and one-hot encode the 10 digit classes.
   def func(ds):
      return ds.map(lambda x, y: (tf.divide(tf.cast(x, dtype=tf.float32), 255.0), tf.one_hot(y, depth=10)))

   # load() already runs the async loader and returns a Dataloader.
   return load(batch_size, func)


def load_dataset_muti(batch_size: int):
   # Same preprocessing, keyed by the input/output layer names used in the functional example.
   def func(ds):
      return ds.map(lambda x, y: ({"img": tf.divide(tf.cast(x, dtype=tf.float32), 255.0)}, {"out1": tf.one_hot(y, depth=10)}))

   return load(batch_size, func)
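The `func` passed to `load` scales pixels to [0, 1] and one-hot encodes the labels. The sketch below applies the same transformation in plain JAX to a fake batch with the MNIST shape (batch, 28, 28, 1). `preprocess` is a hypothetical helper and is not part of sunstreaker.

```python
import jax
import jax.numpy as jnp

def preprocess(images, labels, num_classes=10):
    """Hypothetical helper mirroring `func`: scale uint8 pixels, one-hot the labels."""
    x = images.astype(jnp.float32) / 255.0      # uint8 [0, 255] -> float32 [0, 1]
    y = jax.nn.one_hot(labels, num_classes)      # int label -> length-10 vector
    return x, y

# Fake batch of 4 all-white "images" with the MNIST shape (28, 28, 1).
images = jnp.full((4, 28, 28, 1), 255, dtype=jnp.uint8)
labels = jnp.array([0, 3, 7, 9])
x, y = preprocess(images, labels)
print(x.shape, float(x.max()))   # (4, 28, 28, 1) 1.0
print(y.shape, y[1])             # (4, 10) and a one-hot row with its 1 at index 3

# The multi-output variant wraps the same arrays in dicts keyed by layer name.
features, targets = {"img": x}, {"out1": y}
```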

Sequential API

from sunstreaker.engine.sequential import Model

data = load_dataset(batch_size=1024)
model = Model([Input(input_shape=(28, 28, 1)), Flatten(), Dense(100), Dense(10), Softmax()])
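Conceptually, the sequential model flattens each 28×28×1 image into 784 values. It then applies two Dense layers (784→100, then 100→10) followed by a softmax, which matches the shapes that `model.summary()` prints later in this README. Below is a plain-JAX sketch of that forward pass. The random initialization is an arbitrary choice made only to produce arrays of the right shape; it is not sunstreaker's initializer. As in the layer list, there is no activation between the two Dense layers.

```python
import jax
import jax.numpy as jnp

def init_params(key):
    # Assumption: simple scaled-normal init, only to get arrays of the right shape.
    k1, k2 = jax.random.split(key)
    return {
        "dense_1": (jax.random.normal(k1, (784, 100)) * 0.01, jnp.zeros(100)),
        "dense_2": (jax.random.normal(k2, (100, 10)) * 0.01, jnp.zeros(10)),
    }

def forward(params, images):
    x = images.reshape(images.shape[0], -1)   # Flatten: (B, 28, 28, 1) -> (B, 784)
    w1, b1 = params["dense_1"]
    x = x @ w1 + b1                           # Dense(100), no activation
    w2, b2 = params["dense_2"]
    x = x @ w2 + b2                           # Dense(10)
    return jax.nn.softmax(x, axis=-1)         # Softmax over the 10 classes

params = init_params(jax.random.PRNGKey(0))
probs = forward(params, jnp.ones((2, 28, 28, 1)))
print(probs.shape)   # (2, 10)
```

Without a nonlinearity, the two Dense layers compose into a single linear map. The functional example below adds `activation='relu'` to the first Dense layer.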

Functional API

data = load_dataset_muti(batch_size=1024)
inputs = Input(input_shape=(28, 28, 1), name="img")
flatten = Flatten()(inputs)
dense1 = Dense(100, activation='relu')(flatten)
dense2 = Dense(10, use_bias=False)(dense1)
outputs = Softmax(name="out1")(dense2)

from sunstreaker.engine.functional import Model

model = Model(inputs=inputs, outputs=outputs)
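In the functional example, the input layer is named "img" and the output layer "out1". These are the same keys that `load_dataset_muti` uses in its feature and label dicts, so the names appear to be how batches are routed to layers. The sketch below only illustrates that idea: it is a plain-JAX forward pass that takes and returns dicts. `functional_forward` and its parameter names are hypothetical and are not sunstreaker's API.

```python
import jax
import jax.numpy as jnp

def functional_forward(params, features):
    """Hypothetical dict-in/dict-out forward pass for the functional model."""
    x = features["img"].reshape(features["img"].shape[0], -1)   # Flatten
    x = jax.nn.relu(x @ params["w1"] + params["b1"])            # Dense(100, activation='relu')
    x = x @ params["w2"]                                        # Dense(10, use_bias=False)
    return {"out1": jax.nn.softmax(x, axis=-1)}                 # Softmax(name="out1")

k1, k2 = jax.random.split(jax.random.PRNGKey(0))
params = {
    "w1": jax.random.normal(k1, (784, 100)) * 0.01,
    "b1": jnp.zeros(100),
    "w2": jax.random.normal(k2, (100, 10)) * 0.01,   # no bias, as with use_bias=False
}
features = {"img": jnp.ones((3, 28, 28, 1))}
outputs = functional_forward(params, features)
print(sorted(outputs), outputs["out1"].shape)   # ['out1'] (3, 10)
```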

For experienced users: subclassing Model

from sunstreaker import Model

data = load_dataset(batch_size=1024)

class MyModel(Model):
    def build(self, rng=None):
        self.W = self.add_weight((784, 10))
        self.flatten = Flatten()
        self.softmax = Softmax()
        return (10,), [(self.W,)]

    def call(self, params, inputs, trainable=True, **kwargs):
        self.W, = params[0]
        x = self.flatten.forward(params=[], inputs=inputs)
        x = jnp.dot(x, self.W)
        y = self.softmax.forward(params=[], inputs=x)
        return y


model = MyModel()
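Here `call` receives `params` explicitly instead of reading them from the object. This lets the forward pass be written as a pure function, which is the form `jax.grad` differentiates (presumably the reason for this design). The sketch below reproduces `MyModel` in plain JAX: a single 784×10 weight matrix, i.e. softmax regression without a bias. It then differentiates a cross-entropy loss with respect to the same `[(W,)]` nesting that `build` returns. It does not call sunstreaker.

```python
import jax
import jax.numpy as jnp

def call(params, inputs):
    # Same computation as MyModel.call: flatten, multiply by W (784, 10), softmax.
    (W,) = params[0]
    x = inputs.reshape(inputs.shape[0], -1)
    return jax.nn.softmax(x @ W, axis=-1)

def loss_fn(params, inputs, targets):
    # Categorical cross-entropy averaged over the batch (standard definition).
    probs = call(params, inputs)
    return -jnp.mean(jnp.sum(targets * jnp.log(probs + 1e-7), axis=-1))

params = [(jnp.zeros((784, 10)),)]          # same nesting as build's [(self.W,)]
inputs = jnp.ones((5, 28, 28, 1))
targets = jax.nn.one_hot(jnp.array([0, 1, 2, 3, 4]), 10)

loss, grads = jax.value_and_grad(loss_fn)(params, inputs, targets)
print(float(loss))          # about log(10) = 2.3026, since W = 0 gives uniform probabilities
print(grads[0][0].shape)    # (784, 10): the gradients mirror the params structure
```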

Compile, train, and save

model.compile(loss=categorical_crossentropy, optimizer=RMSProp(lr=0.001), metrics=[categorical_accuracy])
model.fit(data, epochs=10)
model.save("tfds_mnist_v2")
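`compile` pairs the loss with `RMSProp(lr=0.001)`. The textbook RMSProp update keeps a running average of squared gradients and divides each step by its square root. The sketch below applies that rule to a toy least-squares problem. The decay `rho=0.9` and `eps=1e-8` are assumed defaults, not values taken from sunstreaker's optimizer.

```python
import jax
import jax.numpy as jnp

def rmsprop_step(params, grads, state, lr=0.001, rho=0.9, eps=1e-8):
    # state holds the running mean of squared gradients, one entry per parameter.
    state = jax.tree_util.tree_map(lambda s, g: rho * s + (1 - rho) * g * g, state, grads)
    params = jax.tree_util.tree_map(
        lambda p, g, s: p - lr * g / (jnp.sqrt(s) + eps), params, grads, state)
    return params, state

def loss_fn(w, x, y):
    return jnp.mean((x @ w - y) ** 2)   # toy linear-regression loss

x = jax.random.normal(jax.random.PRNGKey(0), (64, 3))
y = x @ jnp.array([1.0, -2.0, 0.5])
w = jnp.zeros(3)
state = jnp.zeros_like(w)
grad_fn = jax.jit(jax.grad(loss_fn))

initial_loss = float(loss_fn(w, x, y))
for _ in range(500):
    w, state = rmsprop_step(w, grad_fn(w, x, y), state, lr=0.01)
final_loss = float(loss_fn(w, x, y))
print(initial_loss, final_loss)   # the loss should drop substantially
```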

Print the model structure

model.summary()
model.plot_model() 
+--------+-----------+---------+-------------+--------------+
| number | name      | class   | input_shape | output_shape |
+--------+-----------+---------+-------------+--------------+
| 0      | input_0   | Input   | (28, 28, 1) | (28, 28, 1)  |
| 1      | flatten_1 | Flatten | (28, 28, 1) | (784,)       |
| 2      | dense_2   | Dense   | (784,)      | (100,)       |
| 3      | dense_4   | Dense   | (100,)      | (10,)        |
| 4      | softmax_6 | Softmax | (10,)       | (10,)        |
+--------+-----------+---------+-------------+--------------+
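The summary lists shapes but not parameter counts. Assume the usual Dense layout: a kernel of in×out weights plus a bias of size out, as in the Dense code in the 0.0.4 notes. Under that assumption, the counts follow directly from the table above.

```python
# (name, input_dim, output_dim) for the Dense rows of the summary table.
dense_layers = [("dense_2", 784, 100), ("dense_4", 100, 10)]

def dense_params(n_in, n_out, use_bias=True):
    # The kernel has n_in * n_out weights; the bias adds n_out more.
    return n_in * n_out + (n_out if use_bias else 0)

counts = {name: dense_params(n_in, n_out) for name, n_in, n_out in dense_layers}
total = sum(counts.values())   # Input, Flatten and Softmax have no weights
print(counts, total)           # {'dense_2': 78500, 'dense_4': 1010} 79510
```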

Visualize losses and metrics

model.plot_losses()
model.plot_accuracy()

Features

0.0.1.dev updates

  • activations: Linear, Softmax, Relu, Sigmoid, Elu, LeakyRelu, Gelu
  • layers: Dense, Flatten, Dropout, Conv2D, MaxPool2D, AveragePooling2D, GRU
  • losses: binary_crossentropy, categorical_crossentropy, mean_squared_error, mean_absolute_error, mean_squared_logarithmic_error, hinge, kl_divergence, huber
  • metrics: binary_accuracy, accuracy, categorical_accuracy, sparse_categorical_accuracy, cosine_similarity_accuracy, top_k_categorical_accuracy, sparse_top_k_categorical_accuracy
  • optimizers: SGD, SM3, Adagrad, Adam, Adamax, RMSProp, FTRL

0.0.2.dev updates

  • layers: Embedding, Lambda, Add, Concatenate, Dot, Multiply, LayerNormalization, InstanceNormalization, BatchNormalization, GroupNormalization, LocalResponseNormalization, UpSampling2D
  • losses: l2_error

0.0.3.dev updates

  • initializer: zeros, ones, constant, uniform, normal, orthogonal, LecunUniform, LecunNormal, GlorotNormal, GlorotUniform, HeNormal, HeUniform, KaimingUniform, KaimingNormal, XavierNormal, XavierUniform, Identity
  • activations: Swish
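`GlorotUniform` is the default `kernel_initializer` for Dense. The standard Glorot (Xavier) uniform rule samples from U(-limit, limit) with limit = sqrt(6 / (fan_in + fan_out)). Below is a plain-JAX sketch of that rule; it is not sunstreaker's own initializer code. JAX itself ships the same rule as `jax.nn.initializers.glorot_uniform()`.

```python
import math
import jax
import jax.numpy as jnp

def glorot_uniform(key, shape):
    """Glorot/Xavier uniform init for a 2-D kernel of shape (fan_in, fan_out)."""
    fan_in, fan_out = shape
    limit = math.sqrt(6.0 / (fan_in + fan_out))
    return jax.random.uniform(key, shape, minval=-limit, maxval=limit)

kernel = glorot_uniform(jax.random.PRNGKey(0), (784, 100))
limit = math.sqrt(6.0 / (784 + 100))
print(kernel.shape, limit)   # (784, 100) and a limit of about 0.0824
```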

0.0.4.dev updates

Core changes

  1. Layer.call no longer takes params, and build no longer returns params. Using Dense as an example:

    import jax.numpy as jnp
    # Layer, activations, GlorotUniform and Zeros come from sunstreaker (import paths omitted here).

    class Dense(Layer):
        def __init__(self, units, activation=None, use_bias=True, kernel_initializer=GlorotUniform(), bias_initializer=Zeros(), **kwargs):
            super().__init__(**kwargs)
            self.use_bias = use_bias
            self.activation = activations.get(activation)()
            self.units = int(units) if not isinstance(units, int) else units
            self.kernel_initializer = kernel_initializer
            self.bias_initializer = bias_initializer
    
        def build(self):
            # Keep every leading dimension and replace the last one with `units`.
            output_shape = self.input_shape[:-1] + (self.units,)
            # Weights are registered by name; the core assigns their random seeds.
            self.add_weight("kernel", (self.input_shape[-1], self.units), initializer=self.kernel_initializer)
            if self.use_bias:
                self.add_weight("bias", (self.units,), initializer=self.bias_initializer)
            return output_shape
    
        def call(self, inputs, **kwargs):
            # Weights are looked up by name instead of being passed in as params.
            outputs = jnp.dot(inputs, self.get_weight("kernel"))
            if self.use_bias:
                outputs = outputs + self.get_weight("bias")
            outputs = self.activation.forward(params=None, inputs=outputs)
            return outputs
    
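  2. Model params are now an ordered dict, which makes loading large-model weights easier.

  3. build no longer takes a random seed as input; the core assigns seeds automatically.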
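The three changes above can be illustrated with a toy base class: weights registered by name instead of passed through `call`, params kept in an ordered dict, and seeds assigned internally. Everything in the sketch (`ToyLayer`, `ToyDense`, `_next_key`) is hypothetical. It only mirrors the `add_weight`/`get_weight` calls in the Dense example and is not sunstreaker's implementation.

```python
from collections import OrderedDict
import jax
import jax.numpy as jnp

class ToyLayer:
    """Hypothetical base class: named weights in an OrderedDict, seeds handed out internally."""

    def __init__(self, input_shape, seed=0):
        self.input_shape = input_shape
        self.params = OrderedDict()          # insertion order = creation order
        self._key = jax.random.PRNGKey(seed)

    def _next_key(self):
        # The layer splits its own key, so build() never needs a seed argument.
        self._key, sub = jax.random.split(self._key)
        return sub

    def add_weight(self, name, shape, initializer):
        self.params[name] = initializer(self._next_key(), shape)

    def get_weight(self, name):
        return self.params[name]

class ToyDense(ToyLayer):
    def __init__(self, input_shape, units, use_bias=True):
        super().__init__(input_shape)
        self.units, self.use_bias = units, use_bias
        self.output_shape = self.build()

    def build(self):
        self.add_weight("kernel", (self.input_shape[-1], self.units),
                        initializer=lambda k, s: jax.random.normal(k, s) * 0.01)
        if self.use_bias:
            self.add_weight("bias", (self.units,), initializer=lambda k, s: jnp.zeros(s))
        return self.input_shape[:-1] + (self.units,)

    def call(self, inputs):
        outputs = jnp.dot(inputs, self.get_weight("kernel"))
        if self.use_bias:
            outputs = outputs + self.get_weight("bias")
        return outputs

layer = ToyDense(input_shape=(784,), units=100)
print(list(layer.params), layer.output_shape)   # ['kernel', 'bias'] (100,)
print(layer.call(jnp.ones((2, 784))).shape)     # (2, 100)
```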

0.0.5.dev updates

  • application: transformers/bert
  • layers: MultiHeadAttention, PositionEmbedding, FeedForward, ScaleOffset, Activation
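0.0.5 adds `MultiHeadAttention` for the BERT application. For reference, the sketch below implements standard scaled dot-product multi-head attention in plain JAX. It has no masking or dropout, and the projection sizes and head count are arbitrary choices for the example. It is not sunstreaker's layer code.

```python
import jax
import jax.numpy as jnp

def multi_head_attention(x, wq, wk, wv, wo, num_heads):
    """x: (seq, d_model); each w*: (d_model, d_model). Returns (seq, d_model)."""
    seq, d_model = x.shape
    d_head = d_model // num_heads

    def split_heads(t):
        # (seq, d_model) -> (num_heads, seq, d_head)
        return t.reshape(seq, num_heads, d_head).transpose(1, 0, 2)

    q, k, v = split_heads(x @ wq), split_heads(x @ wk), split_heads(x @ wv)
    scores = q @ k.transpose(0, 2, 1) / jnp.sqrt(d_head)      # (heads, seq, seq)
    weights = jax.nn.softmax(scores, axis=-1)                 # each row sums to 1
    heads = weights @ v                                       # (heads, seq, d_head)
    merged = heads.transpose(1, 0, 2).reshape(seq, d_model)   # concatenate the heads
    return merged @ wo

keys = jax.random.split(jax.random.PRNGKey(0), 5)
d_model = 8
x = jax.random.normal(keys[0], (4, d_model))
wq, wk, wv, wo = (jax.random.normal(k, (d_model, d_model)) * 0.1 for k in keys[1:])
out = multi_head_attention(x, wq, wk, wv, wo, num_heads=2)
print(out.shape)   # (4, 8)
```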

0.0.6.dev updates

  • application: diffusion/DDPM
  • optimizers: AdamW
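0.0.6 adds AdamW. Its defining feature is decoupled weight decay: the decay term is applied directly to the weights rather than added to the gradient before the Adam moments are computed. Below is a single-array sketch of the standard update. The hyperparameter defaults are assumptions and are not necessarily sunstreaker's.

```python
import jax.numpy as jnp

def adamw_step(p, g, m, v, t, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8, weight_decay=0.01):
    """One AdamW update for a single array; t is the 1-based step count."""
    m = b1 * m + (1 - b1) * g          # first moment (running mean of gradients)
    v = b2 * v + (1 - b2) * g * g      # second moment (running mean of squared gradients)
    m_hat = m / (1 - b1 ** t)          # bias correction
    v_hat = v / (1 - b2 ** t)
    # Decoupled decay: weight_decay * p is applied here, not folded into g above.
    p = p - lr * (m_hat / (jnp.sqrt(v_hat) + eps) + weight_decay * p)
    return p, m, v

p = jnp.array([1.0, -1.0])
zero = jnp.zeros_like(p)
# With a zero gradient, plain Adam would leave p unchanged; AdamW still shrinks it.
p_new, m, v = adamw_step(p, g=zero, m=zero, v=zero, t=1)
print(p_new)   # each entry moves toward 0 by lr * weight_decay * |p| = 1e-5
```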

Citation

Open Source Agenda is not affiliated with "Sunstreaker" Project. README Source: duyongan/sunstreaker