Elixir Meets GPT-2: Implementing GPT-2 with Nx

Sean Moriarity

Machine Learning Advisor

Bumblebee is an Elixir library that gives you access to some very powerful models capable of incredible things. At times, it really can seem like magic. Bumblebee abstracts away many of the details of building models, turning real-world inputs into tensors, and turning tensors into real-world outputs. These abstractions simplify your implementations; however, they can also make the library feel like a black box.

In this post, we’ll attempt to peel back the layers of abstraction by implementing GPT-2 in pure Nx. We’ll also look at how text generation works, and how we can easily wrap our implementation in an Nx.Serving for scalable and distributed model serving.

This post is based on a fantastic blog post by Jay Mody that does the same thing with NumPy. We won’t be code golfing here, but I think you’ll find the Nx implementation is just as straightforward as the NumPy implementation.

Before getting started, you’ll need to install a few dependencies:

Mix.install([
  {:nx, "~> 0.6"},
  {:exla, "~> 0.6"},
  {:safetensors, "~> 0.1"},
  {:tokenizers, "~> 0.3"}
])

Nx is Elixir’s numerical computing library. It is the foundation of the other libraries in this list and implements the key data structure for machine learning and numerical computing: the %Nx.Tensor{}. EXLA is a just-in-time compiler for Nx. It transforms Nx numerical definitions (defn) into optimized CPU and GPU programs. safetensors is an Elixir library that converts .safetensors files into collections of Nx tensors; we’ll use it to load pre-trained GPT-2 parameters. Finally, the tokenizers library is an Elixir wrapper around the HuggingFace tokenizers library. It allows us to use pre-trained tokenizers to encode strings as tensors and decode tensors back into strings.
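If you want a quick sanity check that EXLA is wired up correctly, you can JIT compile a trivial function and run it (purely illustrative, not part of the model):

# Illustrative sanity check: JIT compile a tiny function with EXLA and run it.
add = Nx.Defn.jit(fn a, b -> Nx.add(a, b) end, compiler: EXLA)
add.(Nx.tensor([1, 2, 3]), Nx.tensor([4, 5, 6]))
#=> s64[3] tensor [5, 7, 9], computed by the EXLA backend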

Loading Parameters

The first step in creating our GPT-2 model is loading pre-trained model parameters as Nx tensors. One of the things that makes Bumblebee so awesome is its ability to take pre-trained model parameters from the Python ecosystem and use them directly in Elixir-based implementations of the same models. There are two ways we do this:

  1. Converting PyTorch model parameters (.bin files representing Python pickles) into Nx tensors using the unpickler library from Dashbit.
  2. Converting safetensors model parameters (.safetensors files) into Nx tensors using the safetensors library.

Recently, HuggingFace has started converting a majority of pre-trained model parameters to the safetensors format because it offers a variety of benefits. Fortunately, it also simplifies a lot of our parameter conversion process. For this tutorial, we’ll be using the .safetensors version of the smallest GPT-2 pre-trained model. You can download these parameters here: https://huggingface.co/gpt2/blob/main/model.safetensors.

params = Safetensors.load!(File.read!("gpt2.safetensors"))
%{
  "h.4.mlp.c_fc.weight" => #Nx.Tensor<
    f32[768][3072]
    [
      [-4.077995545230806e-4, -0.1200379952788353, -0.012310190126299858, -0.24376054108142853, 0.1328510195016861, 0.13179974257946014, 0.02635245770215988, 0.057356610894203186, -0.06828179955482483, -0.01686907559633255, 0.049044106155633926, -0.3784016966819763, -0.03531080484390259, 0.43567171692848206, 0.02976839430630207, -0.06014109030365944, 0.18706151843070984, -0.050236742943525314, 0.11668948084115982, 0.05957753583788872, -0.14054043591022491, -0.013522407039999962, -0.06838822364807129, 0.06603340059518814, 0.1704917997121811, -0.04796731844544411, 0.13489077985286713, 0.09338540583848953, -0.4719774127006531, -0.023174606263637543, -0.035188376903533936, 0.0310268085449934, 0.10190644860267639, 0.12008801102638245, -0.10986845195293427, 0.24427762627601624, -0.26449140906333923, 0.03904227167367935, -0.0342307947576046, 0.003329918021336198, 0.10291037708520889, 0.022748058661818504, 0.012120273895561695, 0.017654111608862877, 0.061049483716487885, -0.25464147329330444, 0.017738372087478638, -0.1125129833817482, 0.019292231649160385, ...],
      ...
    ]
  >,
  ...
}

You’ll notice the model parameters consist of a map of strings to tensors. The map’s keys look like h.4.mlp.c_fc.weight. The last segment of the .-delimited key is the parameter name of that particular tensor. The preceding segments map to nested modules. Essentially, if you had a module like:

import torch.nn as nn
import torch.nn.functional as F

class MyModule(nn.Module):
  def __init__(self):
    super().__init__()
    self.fc1 = nn.Linear(64, 32)

  def forward(self, x):
    return F.relu(self.fc1(x))

in PyTorch, this would map to the parameter keys {module_name}.fc1.weight and {module_name}.fc1.bias. In order to work with these parameters, we need to “unflatten” them so that they’re easy to work with in our pure Nx implementation. This is one of the first things Bumblebee does automatically for you — it maps this flattened representation of a model’s parameters to a representation Axon can use.

For this example, we’ll convert the flattened map into a nested map where each value represents the input parameters to a function or layer in our GPT-2 implementation:

blocks =
  Enum.reduce(params, %{}, fn {key, value}, acc ->
    case String.split(key, ".") do
      ["h", block_num, inner_block_name, layer_name, param_name] ->
        init = %{inner_block_name => %{layer_name => %{param_name => value}}}

        Map.update(acc, "block_#{block_num}", init, fn block_params ->
          inner_init = %{layer_name => %{param_name => value}}

          Map.update(block_params, inner_block_name, inner_init, fn inner_block_params ->
            layer_init = %{param_name => value}

            Map.update(inner_block_params, layer_name, layer_init, fn layer_params ->
              Map.put(layer_params, param_name, value)
            end)
          end)
        end)

      ["h", block_num, layer_name, param_name] ->
        init = %{layer_name => %{param_name => value}}

        Map.update(acc, "block_#{block_num}", init, fn block_params ->
          layer_init = %{param_name => value}

          Map.update(block_params, layer_name, layer_init, fn layer_params ->
            Map.put(layer_params, param_name, value)
          end)
        end)

      [layer_name, param_name] ->
        Map.update(acc, layer_name, %{param_name => value}, fn layer_params ->
          Map.put(layer_params, param_name, value)
        end)
    end
  end)
%{
  "block_0" => %{
    "attn" => %{
      "bias" => #Nx.Tensor<
        f32[1][1][1024][1024]
        [
          [
            [
              [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...],
              ...
            ]
          ]
        ]
      >,
      "c_attn" => %{...}
      ...
    }
  }
}

After running this, you’ll have a nested map that we can traverse to build our model.
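For example, we can sanity-check a couple of the nested parameters by looking up their shapes (an illustrative check; the shapes should reflect GPT-2 small’s 768-dimensional hidden size and 50257-token vocabulary):

Nx.shape(blocks["block_0"]["ln_1"]["weight"])
#=> {768}

Nx.shape(blocks["wte"]["weight"])
#=> {50257, 768}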

Encoding Text

Another great convenience Bumblebee offers is its handling of pre/post-processing for machine-learning tasks for you. When dealing with text, this manifests in the form of tokenization. Language models rely on tokenizers to convert sequences of input into discrete numerical representations of the text. Most often, this comes in the form of subword tokenization. Models like GPT-2 rely on probabilistic techniques to build fixed-size vocabularies of subwords. These vocabularies are then used to translate text representations into sequences of integers and vice-versa.

Bumblebee makes use of HuggingFace’s tokenizers library for much of the pre/post-processing related to text-based models like GPT-2. For this tutorial, all we really need is a module that:

  1. Instantiates a tokenizer from a pre-trained one
  2. Encodes text as a tensor
  3. Decodes integers back to text

We can implement our encoder like this:

defmodule Encoder do
  defstruct [:tokenizer]

  def new(model_id) do
    {:ok, tokenizer} = Tokenizers.Tokenizer.from_pretrained(model_id)
    %__MODULE__{tokenizer: tokenizer}
  end

  def encode(%{tokenizer: tokenizer}, text) do
    {:ok, encoding} = Tokenizers.Tokenizer.encode(tokenizer, text)

    Nx.tensor(Tokenizers.Encoding.get_ids(encoding))
    |> Nx.new_axis(0)
  end

  def decode(%{tokenizer: tokenizer}, id) do
    {:ok, token} = Tokenizers.Tokenizer.decode(tokenizer, [id])
    token
  end
end

Now, we can instantiate a new “encoder” using Encoder.new/1 and passing in a HuggingFace tokenizer id. This id maps to a HuggingFace model repository:

encoder = Encoder.new("gpt2")
%Encoder{
  tokenizer: #Tokenizers.Tokenizer<[
    vocab_size: 50257,
    byte_fallback: false,
    continuing_subword_prefix: "",
    dropout: nil,
    end_of_word_suffix: "",
    fuse_unk: false,
    model_type: "bpe",
    unk_token: nil
  ]>
}

Then, we can use this encoder to encode text as tensors:

Encoder.encode(encoder, "Hello world!")
#Nx.Tensor<
  s64[1][3]
  [
    [15496, 995, 0]
  ]
>

And decode integers to text:

Encoder.decode(encoder, 15496)
"Hello"

And that’s really all we need! Bumblebee also takes care of things like attention masks, padding inputs, and more, but those are outside the scope of this blog post.

Modeling Transformers with Nx

Now it’s time to model GPT-2 as Nx numerical definitions. Numerical definitions (defn) are just functions that support a special subset of the Elixir programming language. defn is meant to be a boundary for pure Nx code because everything within a defn has the potential to be JIT compiled. We can build our model up such that each “layer” or “block” maps to a numerical definition. Transformer models like GPT-2 are relatively simple and consist of a few steps:

  1. An embedding computed from the input sequence (token and positional)
  2. A series of “transformer” blocks
  3. A final layer normalization and linear transformation

We can start by implementing our “top-level” model with this flow:

defmodule GPT do
  import Nx.Defn

  defn predict(input, wte, wpe, blocks, ln_f, opts \\ []) do
    opts = keyword!(opts, [n_head: 12])

    input
    |> embedding(wte, wpe)
    |> transformer(blocks, n_head: opts[:n_head])
    |> layer_norm(ln_f)
    |> Nx.dot([-1], wte["weight"], [-1])
  end
end

That’s really all our high-level model does. Of course, we still need to implement many of these layers. Let’s start with the embedding layer:

defn embedding(x, %{"weight" => wte}, %{"weight" => wpe}) do
  position_ids = Nx.iota({Nx.axis_size(x, 0), Nx.axis_size(x, 1)}, axis: -1)
  Nx.take(wte, x) + Nx.take(wpe, position_ids)
end

Our embedding layer consists of a token embedding and a positional embedding. The token embedding learns to map input tokens in our vocab to vectors. The operation is essentially a “lookup” of the embedding row at the index given by each integer ID in the input sequence. The positional embedding does exactly the same thing; however, its IDs correspond to positions in the input sequence. The idea is that the result encodes both semantic information from the token embedding and positional information from the positional embedding.
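To make the lookup concrete, here is a tiny, made-up example (the table and IDs below are not GPT-2’s, just an illustration of how Nx.take behaves):

# A toy vocabulary of 4 tokens embedded in 2 dimensions.
wte = Nx.tensor([[0.0, 0.1], [1.0, 1.1], [2.0, 2.1], [3.0, 3.1]])
ids = Nx.tensor([[2, 0, 3]])
Nx.take(wte, ids)
#=> shape {1, 3, 2}: rows 2, 0, and 3 of the table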

Now things get interesting—we need to implement our transformer. The transformer portion of the model consists of a number of transformer_blocks applied successively. In other words, the transformer implementation looks like this:

deftransform transformer(x, params, opts \\ []) do
  n_head = opts[:n_head]
  Enum.reduce(params, x, fn {_block_name, block_params}, x ->
    transformer_block(x, block_params, n_head: n_head)
  end)
end

This layer is implemented as a deftransform, which is an escape hatch for using regular Elixir functions inside numerical definitions. Essentially, deftransform operates on Nx expressions—which means we can use things like Enum.reduce to lazily build up our operations. We need it here so we can apply transformer_block to the input x once per block in our parameters. Because the reduction happens while the expression is being built, the loop is effectively unrolled: with GPT-2 small’s 12 blocks, the compiled program contains 12 successive transformer_block applications.

Now we can implement transformer_block. The transformer_block consists of:

  1. Layer normalization
  2. Multi-head self-attention
  3. A residual connection
  4. Layer normalization
  5. A point-wise feed-forward network (FFN)
  6. Another residual connection

We can then implement our transformer block like this:

defn transformer_block(
       x,
       %{"mlp" => mlp, "attn" => attn, "ln_1" => ln_1, "ln_2" => ln_2},
       opts \\ []
     ) do
  opts = keyword!(opts, [n_head: 12])

  attention =
    x
    |> layer_norm(ln_1)
    |> mha(attn, n_head: opts[:n_head])
    |> Nx.add(x)

  attention + ffn(layer_norm(attention, ln_2), mlp)
end

And that’s pretty much it! Okay, so now we need to go about implementing the layers used here. We’ll start with mha or multi-head self-attention as that is the “meat” of the transformer. There is a lot of literature on what exactly attention is, so I will omit lengthy explanations and mathematical details. Attention is an operation that computes the relative importance of tokens in two sequences to one another.

Essentially, we compute a relationship matrix between two sequences. In transformers, we use something called self-attention to compute the relationship between an input sequence and itself. The “multi-head” terminology comes in as we split the embedded representation of our text into multiple heads to compute attention multiple times. It’s kind of like an ensembling technique. Rather than getting a single relationship (attention) matrix per transformer block, we get multiple.
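To make the shapes concrete, here is a rough sketch (separate from the model code) of what splitting GPT-2 small’s 768-dimensional representation into 12 heads looks like:

# With a hidden size of 768 and 12 heads, each head operates on 64 dimensions.
x = Nx.iota({1, 3, 768}, type: :f32)      # {batch, seq, dim}
heads = Nx.reshape(x, {1, 3, 12, :auto})  # {batch, seq, n_head, head_dim}
Nx.shape(heads)
#=> {1, 3, 12, 64}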

Our mha implementation consists of the following steps:

  1. Compute a linear projection of the input
  2. Split the input into query, key, and value tensors
  3. Split q (query), k (key), and v (value) to use multiple heads
  4. Compute a causal_mask
  5. Perform attention operation
  6. Compute output projection

In code this looks like:

defn mha(x, %{"c_attn" => c_attn, "c_proj" => c_proj}, opts \\ []) do
  opts = keyword!(opts, [n_head: 12])
  x = linear(x, c_attn)

  {q, k, v} = split_qkv(x)
  q = split_heads(q, opts[:n_head])
  k = split_heads(k, opts[:n_head])
  v = split_heads(v, opts[:n_head])

  causal_mask = (1 - Nx.tri(Nx.axis_size(x, 1), Nx.axis_size(x, 1))) * -1.0e10
  out = attention(q, k, v, causal_mask)

  linear(out, c_proj)
end

From here, we need to implement the logic to split x into {q, k, v}, like this:

deftransformp split_qkv(tensor) do
  split_size = div(Nx.axis_size(tensor, -1), 3)
  q = tensor[[.., .., 0..(split_size - 1)]]
  k = tensor[[.., .., split_size..(2*split_size - 1)]]
  v = tensor[[.., .., 2*split_size..-1//1]]
  {q, k, v}
end

This essentially slices the input tensor into 3 distinct tensors that will be used for our attention operation. Next, we need to implement the logic for reshaping q, k, and v into multiple heads:

deftransformp split_heads(tensor, n_heads) do
  {batch, seq, _dim} = Nx.shape(tensor)
  Nx.reshape(tensor, {batch, seq, n_heads, :auto})    
end

Finally, we need to implement our actual attention operation:

defn attention(q, k, v, mask) do
  k = Nx.transpose(k, axes: [0, 2, 1, 3])
  q = Nx.transpose(q, axes: [0, 2, 1, 3])
  v = Nx.transpose(v, axes: [0, 2, 1, 3])
  
  q
  |> Nx.divide(Nx.sqrt(Nx.axis_size(q, -1)))
  |> Nx.dot([3], [0, 1], k, [3], [0, 1])
  |> Nx.add(mask)
  |> softmax()
  |> Nx.dot([3], [0, 1], v, [2], [0, 1])
  |> Nx.transpose(axes: [0, 2, 1, 3])
  |> flatten_heads()
end

The attention implementation mostly consists of shape manipulations and dot-products. The causal mask is added to the raw attention scores, and a softmax then normalizes them into attention weights before the final dot-product with v. We’ll implement softmax in a bit. First, let’s implement flatten_heads, which combines our multiple attention heads into a single output:

deftransformp flatten_heads(tensor) do
  shape = Nx.shape(tensor)
  rank = Nx.rank(tensor)

  new_shape =
    shape
    |> Tuple.delete_at(rank - 1)
    |> put_elem(rank - 2, :auto)
  
  Nx.reshape(tensor, new_shape)
end

Now we need to go back and fill in some of our other missing pieces. We’ll start with ffn. ffn represents a basic feed-forward neural network (multi-layer perceptron):

defn ffn(x, %{"c_fc" => c_fc, "c_proj" => c_proj}) do
  x
  |> linear(c_fc)
  |> gelu()
  |> linear(c_proj)
end

This consists of a projection up with a “linear” or “dense” layer, followed by a GeLU activation function, and then a projection down with another linear layer.
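In GPT-2 small, the up projection goes from 768 to 3072 dimensions and the down projection goes back to 768, which you can confirm from the parameters we loaded earlier (an illustrative check):

Nx.shape(blocks["block_0"]["mlp"]["c_fc"]["weight"])
#=> {768, 3072}

Nx.shape(blocks["block_0"]["mlp"]["c_proj"]["weight"])
#=> {3072, 768}

Next, we can implement our linear layer like this: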

@doc """
Linear layer.

## Examples

    iex> {x, key} = Nx.Random.uniform(Nx.Random.key(42), shape: {32, 128})
    iex> {w, key} = Nx.Random.uniform(key, shape: {128, 256})
    iex> {b, _key} = Nx.Random.uniform(key, shape: {256})
    iex> out = GPT.linear(x, %{"weight" => w, "bias" => b})
    iex> Nx.shape(out)
    {32, 256}
    iex> Nx.type(out)
    {:f, 32}
"""
defn linear(x, %{"weight" => w, "bias" => b}) do
  b + Nx.dot(x, w)
end

Note that I’ve added doctests to this implementation. When working in a Livebook, you can add these doctests and they will automatically run when the module is compiled! This will help you check your implementation as you move forward.

Our linear layer multiplies the input by a weight matrix and then adds a bias to the result.

The final “layer” we need to implement (besides our activation functions) is layer normalization. Layer normalization normalizes the input and then scales and shifts the normalized input according to learned parameters:

@doc """
Applies Layer Normalization.

## Examples

    iex> x = Nx.tensor([[2, 2, 3], [-5, 0, 1]])
    iex> actual = GPT.layer_norm(x, %{"weight" => Nx.broadcast(1.0, {2, 1}), "bias" => Nx.broadcast(0.0, {2, 1})})
    iex> expected = Nx.tensor([
    ...>   [-0.70709, -0.70709, 1.41418],
    ...>   [-1.397, 0.508, 0.889]
    ...> ])
    iex> Nx.all_close(actual, expected, atol: 1.0e-3)
    #Nx.Tensor<
      u8
      1
    >
"""
defn layer_norm(x, %{"weight" => w, "bias" => b}, opts \\ []) do
  opts = keyword!(opts, [eps: 1.0e-5])
  mean = Nx.mean(x, axes: [-1], keep_axes: true)
  std_dev = Nx.standard_deviation(x, axes: [-1], keep_axes: true)
  x = (x - mean) / (std_dev + opts[:eps])
  w * x + b
end

With our layers implemented, we just need to implement the two “activation” functions used in our implementation and then our model is ready! Activation functions are point-wise (applied to each entry in a tensor) non-linear functions. We’ll start with GeLU:

@doc """
Applies GeLU Activation.

## Examples

    iex> actual = GPT.gelu(Nx.tensor([[1, 2], [-2, 0.5]]))
    iex> expected = Nx.tensor(([[0.84119, 1.9546], [-0.0454, 0.34571]]))
    iex> Nx.all_close(actual, expected, atol: 1.0e-3)
    #Nx.Tensor<
      u8
      1
    >
"""
defn gelu(x) do
  gaussian_const = Nx.sqrt(2 / Nx.Constants.pi())
  0.5 * x * (1 + Nx.tanh(gaussian_const * (x + 0.044715 * x ** 3)))
end

And then Softmax:

@doc """
Applies Softmax Activation.

## Examples

    iex> actual = GPT.softmax(Nx.tensor([[2, 100], [-5, 0]]))
    iex> expected = Nx.tensor([[2.74878501e-43, 1.0],[6.69285092e-03, 9.93307149e-01]])
    iex> Nx.all_close(actual, expected, atol: 1.0e-3)
    #Nx.Tensor<
      u8
      1
    >
"""
defn softmax(x) do
  exp_x = Nx.exp(x - Nx.reduce_max(x, axes: [-1], keep_axes: true))
  exp_x / Nx.sum(exp_x, axes: [-1], keep_axes: true)
end

And that’s it! Overall, your GPT module should look like this:

defmodule GPT do
  import Nx.Defn

  defn predict(input, wte, wpe, blocks, ln_f, opts \\ []) do
    opts = keyword!(opts, n_head: 12)

    input
    |> embedding(wte, wpe)
    |> transformer(blocks, n_head: opts[:n_head])
    |> layer_norm(ln_f)
    |> Nx.dot([-1], wte["weight"], [-1])
  end

  defn embedding(x, %{"weight" => wte}, %{"weight" => wpe}) do
    position_ids = Nx.iota({Nx.axis_size(x, 0), Nx.axis_size(x, 1)}, axis: -1)
    Nx.take(wte, x) + Nx.take(wpe, position_ids)
  end

  deftransform transformer(x, params, opts \\ []) do
    Enum.reduce(params, x, fn {_block_name, block_params}, x ->
      transformer_block(x, block_params, n_head: opts[:n_head])
    end)
  end

  defn transformer_block(
         x,
         %{"mlp" => mlp, "attn" => attn, "ln_1" => ln_1, "ln_2" => ln_2},
         opts \\ []
       ) do
    opts = keyword!(opts, n_head: 12)

    x =
      x
      |> layer_norm(ln_1)
      |> mha(attn, n_head: opts[:n_head])
      |> Nx.add(x)

    x
    |> layer_norm(ln_2)
    |> ffn(mlp)
    |> Nx.add(x)
  end

  defn mha(x, %{"c_attn" => c_attn, "c_proj" => c_proj}, opts \\ []) do
    opts = keyword!(opts, n_head: 12)
    x = linear(x, c_attn)

    {q, k, v} = split_qkv(x)
    q = split_heads(q, opts[:n_head])
    k = split_heads(k, opts[:n_head])
    v = split_heads(v, opts[:n_head])

    causal_mask = (1 - Nx.tri(Nx.axis_size(x, 1), Nx.axis_size(x, 1))) * -1.0e10
    out = attention(q, k, v, causal_mask)

    linear(out, c_proj)
  end

  deftransformp split_qkv(tensor) do
    split_size = div(Nx.axis_size(tensor, -1), 3)
    q = tensor[[0..-1//1, 0..-1//1, 0..(split_size - 1)]]
    k = tensor[[0..-1//1, 0..-1//1, split_size..(2 * split_size - 1)]]
    v = tensor[[0..-1//1, 0..-1//1, (2 * split_size)..-1//1]]
    {q, k, v}
  end

  deftransformp split_heads(tensor, n_head) do
    {batch, seq, _dim} = Nx.shape(tensor)
    Nx.reshape(tensor, {batch, seq, n_head, :auto})
  end

  defn attention(q, k, v, mask) do
    k = Nx.transpose(k, axes: [0, 2, 1, 3])
    q = Nx.transpose(q, axes: [0, 2, 1, 3])
    v = Nx.transpose(v, axes: [0, 2, 1, 3])

    q
    |> Nx.divide(Nx.sqrt(Nx.axis_size(q, -1)))
    |> Nx.dot([3], [0, 1], k, [3], [0, 1])
    |> Nx.add(mask)
    |> softmax()
    |> Nx.dot([3], [0, 1], v, [2], [0, 1])
    |> Nx.transpose(axes: [0, 2, 1, 3])
    |> flatten_heads()
  end

  deftransformp flatten_heads(tensor) do
    shape = Nx.shape(tensor)
    rank = Nx.rank(tensor)

    new_shape =
      shape
      |> Tuple.delete_at(rank - 1)
      |> put_elem(rank - 2, :auto)

    Nx.reshape(tensor, new_shape)
  end

  defn ffn(x, %{"c_fc" => c_fc, "c_proj" => c_proj}) do
    x
    |> linear(c_fc)
    |> gelu()
    |> linear(c_proj)
  end

  @doc """
  Linear layer.

  ## Examples

      iex> {x, key} = Nx.Random.uniform(Nx.Random.key(42), shape: {32, 128})
      iex> {w, key} = Nx.Random.uniform(key, shape: {128, 256})
      iex> {b, _key} = Nx.Random.uniform(key, shape: {256})
      iex> out = GPT.linear(x, %{"weight" => w, "bias" => b})
      iex> Nx.shape(out)
      {32, 256}
      iex> Nx.type(out)
      {:f, 32}
  """
  defn linear(x, %{"weight" => w, "bias" => b}) do
    x |> Nx.dot(w) |> Nx.add(b)
  end

  @doc """
  Applies Layer Normalization.

  ## Examples

      iex> x = Nx.tensor([[2, 2, 3], [-5, 0, 1]])
      iex> actual = GPT.layer_norm(x, %{"weight" => Nx.broadcast(1.0, {2, 1}), "bias" => Nx.broadcast(0.0, {2, 1})})
      iex> expected = Nx.tensor([
      ...>   [-0.70709, -0.70709, 1.41418],
      ...>   [-1.397, 0.508, 0.889]
      ...> ])
      iex> Nx.all_close(actual, expected, atol: 1.0e-3)
      #Nx.Tensor<
        u8
        1
      >
  """
  defn layer_norm(x, %{"weight" => w, "bias" => b}, opts \\ []) do
    opts = keyword!(opts, eps: 1.0e-5)
    mean = Nx.mean(x, axes: [-1], keep_axes: true)
    variance = Nx.variance(x, axes: [-1], keep_axes: true)
    x = (x - mean) / Nx.sqrt(variance + opts[:eps])
    w * x + b
  end

  @doc """
  Applies GeLU Activation.

  ## Examples

      iex> actual = GPT.gelu(Nx.tensor([[1, 2], [-2, 0.5]]))
      iex> expected = Nx.tensor(([[0.84119, 1.9546], [-0.0454, 0.34571]]))
      iex> Nx.all_close(actual, expected, atol: 1.0e-3)
      #Nx.Tensor<
        u8
        1
      >
  """
  defn gelu(x) do
    0.5 * x * (1 + Nx.tanh(Nx.sqrt(2 / Nx.Constants.pi()) * (x + 0.044715 * Nx.pow(x, 3))))
  end

  @doc """
  Applies Softmax Activation.

  ## Examples

      iex> actual = GPT.softmax(Nx.tensor([[2, 100], [-5, 0]]))
      iex> expected = Nx.tensor([[2.74878501e-43, 1.0],[6.69285092e-03, 9.93307149e-01]])
      iex> Nx.all_close(actual, expected, atol: 1.0e-3)
      #Nx.Tensor<
        u8
        1
      >
  """
  defn softmax(x) do
    exp_x = Nx.exp(x - Nx.reduce_max(x, axes: [-1], keep_axes: true))
    exp_x / Nx.sum(exp_x, axes: [-1], keep_axes: true)
  end
end
4 doctests, 0 failures
{:module, GPT, <<70, 79, 82, 49, 0, 0, 59, ...>>, true}

With our model implemented, we can create a predict function that takes an input and model parameters and produces an output tensor:

predict_fun = fn input, params ->
  {wte, params} = Map.pop!(params, "wte")
  {wpe, params} = Map.pop!(params, "wpe")
  {ln_f, params} = Map.pop!(params, "ln_f")

  GPT.predict(input, wte, wpe, params, ln_f)
end
#Function<41.3316493/2 in :erl_eval.expr/6>

And of course, we’ll want to JIT compile our predict function so it runs accelerated:

predict_fun = Nx.Defn.jit(predict_fun, compiler: EXLA)
#Function<134.64864510/2 in Nx.Defn.Compiler.fun/2>

Now we can get some input and pass it to our model!

input = Encoder.encode(encoder, "Hello World!")
predict_fun.(input, blocks)
#Nx.Tensor<
  f32[1][3][50257]
  EXLA.Backend<host:0, 0.879302795.2891055124.147515>
  [
    [
      [-17.95338249206543, -10.991355895996094, -13.45356273651123, -19.010509490966797, -22.15827751159668, -19.9477596282959, -14.55197525024414, -11.953147888183594, -12.225095748901367, -22.248506546020508, -14.363884925842285, -9.260599136352539, -10.11651611328125, -12.922428131103516, -14.436551094055176, -12.716976165771484, -10.315630912780762, -14.780750274658203, -10.36999797821045, -11.931602478027344, -15.256696701049805, -12.662351608276367, -13.149083137512207, -13.825837135314941, -18.168048858642578, -11.906233787536621, -18.675334930419922, -15.466197967529297, -15.000768661499023, -15.671730041503906, -15.362894058227539, -10.682412147521973, -13.200881958007812, -15.150506019592285, -14.514249801635742, -9.522102355957031, -10.189918518066406, -14.229316711425781, -13.961867332458496, -13.763506889343262, -10.999066352844238, -14.516007423400879, -14.772758483886719, -12.857023239135742, -13.77109146118164, -10.552913665771484, -12.120234489440918, -18.59151840209961, -13.763373374938965, -12.64120101928711, ...],
      ...
    ]
  ]
>

And that’s it! You just implemented GPT-2 in pure Nx! Now we can use this model to generate some text.

Generating Text

The output of the GPT-2 model might seem like just a random tensor; however, we can use it to generate text. You’ll notice the final dimension of the tensor has a size of 50257, which matches the size of the GPT-2 vocabulary: each position in the sequence gets one logit per token in the vocabulary. We can turn our model output into a next-token prediction with the following computation:

output = predict_fun.(input, blocks)
logits = output[[.., -1]]
next_token = Nx.argmax(logits, axis: -1)
#Nx.Tensor<
  s64[1]
  EXLA.Backend<host:0, 0.879302795.2891055124.147519>
  [84]
>

This grabs the logits for the last position in the sequence and then computes their argmax, which is the ID of the predicted next token. We can repeatedly append this prediction to our input sequence to keep getting next-token predictions from our model. This is easily modeled with Enum.reduce_while:

defmodule Generator do
  def generate(predict_fun, encoder, input, params, eos_id, max_seq_len) do
    encoded_input = Encoder.encode(encoder, input)
    seq_len = Nx.axis_size(encoded_input, 1)

    Enum.reduce_while(seq_len..max_seq_len, encoded_input, fn _idx, current_input ->
      output = predict_fun.(current_input, params)
      logits = output[[.., -1]]
      next_token = Nx.argmax(logits, axis: -1, keep_axis: true)

      if eos_id == Nx.to_number(Nx.squeeze(next_token)) do
        {:halt, current_input}
      else
        IO.write("#{Encoder.decode(encoder, Nx.to_number(Nx.squeeze(next_token)))}")
        new_sequence = Nx.concatenate([current_input, next_token], axis: -1)
        {:cont, new_sequence}
      end
    end)
  end
end
{:module, Generator, <<70, 79, 82, 49, 0, 0, 12, ...>>, {:generate, 6}}

This generation function continuously predicts the next token and creates a new sequence from it. It will stop predicting if the model outputs its EOS (end-of-sequence) token or if it reaches the specified max sequence length. For this example, we decode tokens at each step just to inspect what our model is doing:

Generator.generate(predict_fun, encoder, "Elixir is", blocks, 50256, 256)

Now, if you know a thing or two about Nx and JIT compilation, you’ll know this implementation is inefficient: because the input shape grows at every step, we compile a new computation for each sequence length. This is meant to be a simple tutorial—Bumblebee uses more complex implementations that avoid recompilation, support streaming, and a bunch of other things: https://github.com/elixir-nx/bumblebee/blob/main/lib/bumblebee/text/generation.ex
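One common mitigation (roughly what Bumblebee’s generation code does, minus the attention mask) is to pad the encoded input to a fixed bucket length so the compiled function always sees the same shape, and then read the logits at the position of the last real token instead of at -1. A minimal sketch, assuming a bucket length of 32:

# Sketch only: keep the input shape fixed so the JIT cache is reused across steps.
# Because the model is causal, right padding cannot influence earlier positions,
# so the logits at the last real token stay the same. A production implementation
# would also thread an attention mask through the model.
bucket_len = 32
input = Encoder.encode(encoder, "Elixir is")
real_len = Nx.axis_size(input, 1)
padded = Nx.pad(input, 0, [{0, 0, 0}, {0, bucket_len - real_len, 0}])

output = predict_fun.(padded, blocks)
next_token = output[[.., real_len - 1]] |> Nx.argmax(axis: -1)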

Servings and Inference

Perhaps the best thing Bumblebee has to offer is a collection of pre-defined servings. Servings encapsulate machine learning tasks and can be easily added to your application’s supervision tree and used anywhere in your application for scalable machine learning inference. Servings are really just data structures that encapsulate preprocessing, inference, and postprocessing. A serving for this model might look like:

serving =
  Nx.Serving.new(fn _, _ -> &predict_fun.(&1, blocks) end)
  |> Nx.Serving.client_preprocessing(fn input ->
    {Nx.Batch.concatenate([Encoder.encode(encoder, input)]), :ok}
  end)
  |> Nx.Serving.client_postprocessing(fn {token_ids, :server_info}, _meta ->
    token_ids[[.., -1]]
    |> Nx.argmax(axis: -1)
    |> Nx.squeeze()
    |> Nx.to_number()
    |> then(&Encoder.decode(encoder, &1))
  end)
%Nx.Serving{
  module: Nx.Serving.Default,
  arg: #Function<41.3316493/2 in :erl_eval.expr/6>,
  client_preprocessing: #Function<42.3316493/1 in :erl_eval.expr/6>,
  client_postprocessing: #Function<41.3316493/2 in :erl_eval.expr/6>,
  streaming: nil,
  batch_size: nil,
  distributed_postprocessing: &Function.identity/1,
  process_options: [],
  defn_options: []
}
Nx.Serving.run(serving, "Hello world")
"о"

Nx.Serving is a powerful abstraction that supports load balancing and distribution by default. It also supports dynamic batch inference. The process of creating a serving that handles dynamic batching of requests for you is as simple as:

Supervisor.start_link(
  [
    {Nx.Serving, name: GPTInference, serving: serving}
  ],
  strategy: :one_for_one
)
{:ok, #PID<0.1036.0>}

And then you can get inferences from the named process:

Nx.Serving.batched_run(GPTInference, "Hello!")
" too"

You can use it to transform your simple Nx functions into scalable machine learning deployments embedded directly in Phoenix applications. I highly suggest reading more about it here: https://hexdocs.pm/nx/Nx.Serving.html
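For instance, in a Phoenix app you might add the serving to your application’s supervision tree in application.ex (a sketch only; MyAppWeb.Endpoint and GPTInference are placeholder names, and serving is the struct we built above):

# lib/my_app/application.ex (sketch)
children = [
  MyAppWeb.Endpoint,
  {Nx.Serving, name: GPTInference, serving: serving}
]

Supervisor.start_link(children, strategy: :one_for_one)

Any process in the application, such as a controller or LiveView, can then call Nx.Serving.batched_run(GPTInference, input) to get batched inference from the shared serving process.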

Conclusion

And that’s it! In this post, you implemented GPT-2 from scratch using just Nx. Hopefully, this gives you some intuition (and appreciation) of what Bumblebee is doing under the hood. Until next time!
