MLX: An array framework for Apple silicon
- `mx.linalg.svd` (CPU only)
- `nn.RNN`, `nn.LSTM`, `nn.GRU`
- `mx.fast.scaled_dot_product_attention` fused op
- `mx.array`
- `mx.topk`
- `mx.where` properly handles `inf`
- `atleast_{1,2,3}d` accept any number of arrays
- `nn.Upsample` layer
- `arange` throws on `inf` inputs
- `logsumexp` `inf` edge case
- `inf` constants
- `mx.compile(function, shapeless=True)`
- `mx.atleast_1d`, `mx.atleast_2d`, `mx.atleast_3d`
- `tolist` with `bfloat16` and `float16`
- `argmax` on M3
- `mx.fast` subpackage
- `mx.fast.rope` up to 20x faster
- `safetensors`
- `bfloat16` quantized matrix-vector multiplies
- `mx.fast` subpackage with a fast RoPE
- `mx.stream` to set the default device
- `optimizers.step_decay`
- `optimizers.cosine_decay`
- `optimizers.exponential_decay`
- `mx.compile` makes stuff go fast
- `mx.compile` function transformation
- `__abs__` overload for `abs` on arrays
- `loc` and `scale` parameters for `mx.random.normal`
- `mx.var` to give `inf` with `ddof >= nelem`
- `nn.Sequential`
- `mx.checkpoint`
- `mx.linalg.qr`
- `mx.eval`
- `mx.diag`, `mx.diagonal`
- `array.shape` is a Python tuple
- `int64` and `uint64` reductions: `sum`, `prod`, `max`, `min`, `all`, `any`, `argmax`, `argmin`
- `inf` works, and fix `mx.isinf`
- `mx.full`
- `NaN` in some binary ops: `mx.logaddexp`, `mx.maximum`, `mx.minimum`
- `mx.log1p` with `inf` input
- `Q4_0`, `Q4_1`, and `Q8_0` quantization formats
- `Module.save_weights` supports safetensors
- `nn.init` package with several commonly used neural network initializers
- `Adafactor` in `nn.optimizers`
- `isinf` and friends for integer types
- `int64`, `uint`, and `float32` with `0` inputs
- `inf` reads in `gemv`
- `mx.arange` crashes on `NaN` inputs
- `mx.isnan`, `mx.isinf`, `isposinf`, `isneginf`
- `mx.tile`
- `scatter_min` and `scatter_max`
- `mx.eye`
- `mx.round` to follow NumPy, which rounds to even
- `at[]` syntax for scatter-style operations: `x.at[idx].add(y)` (`min`, `max`, `prod`, etc.)
- `mx.array([x, y])`
- `mx.inner`, `mx.outer`
- In-place operators (`+=`, `*=`, `-=`, ...)
- Constants (`mx.pi`, `mx.inf`, `mx.newaxis`, …)
- `cosine_similarity` loss
- `RoPE` and `ALiBi`
- `tri`
- `retain_graph`