Implementing a neural Part-of-Speech tagger

by Jonathan K. Kummerfeld

DyNet, PyTorch and Tensorflow are complex frameworks with different ways of approaching neural network implementation and variations in default behaviour. This page is intended to show how to implement the same non-trivial model in all three. The design of the page is motivated by my own preference for a complete program with annotations, rather than the more common tutorial style of introducing code piecemeal in between discussion. The design of the code is also geared towards providing a complete picture of how things fit together. For a non-tutorial version of this code it would be better to use abstraction to improve flexibility, but that would have complicated the flow here.

Model: The three implementations below all define a part-of-speech tagger with word embeddings initialised using GloVe, fed into a one-layer bidirectional LSTM, followed by a matrix multiplication to produce scores for tags. They all score ~97.2% on the development set of the Penn Treebank. The specific hyperparameter choices follow Yang, Liang, and Zhang (CoLing 2018) and match their performance for the setting without a CRF layer or character-based word embeddings. The repository for this page provides the code in runnable form. The only dependencies are the respective frameworks (DyNet 2.0.3, PyTorch 0.4.1 and Tensorflow 1.9.0).

Website usage: Use the buttons to show one or more implementations and their associated comments (note, depending on your screen size you may need to scroll to see all the code). Matching or closely related content is aligned. Framework-specific comments are highlighted in a colour that matches their button and a line is used to make the link from the comment to the code clear.

New (2019) Runnable Version: I have made a slightly modified version of the Tensorflow code available as a Google Colaboratory Notebook.

Making this helped me understand all three frameworks better. Hopefully you will find it informative too!

We use argparse for processing command line arguments, random for shuffling our data, sys for flushing output, and numpy for handling vectors of data.

# DyNet Implementation
import argparse
import random
import sys

import numpy as np

# PyTorch Implementation
import argparse
import random
import sys

import numpy as np

# Tensorflow Implementation
import argparse
import random
import sys

import numpy as np

Typically, we would make many of these constants command line arguments and tune using the development set. For simplicity, I have fixed their values here to match Yang, Liang and Zhang (CoLing 2018).

PAD = "__PAD__"
UNK = "__UNK__"
DIM_EMBEDDING = 100 
LSTM_HIDDEN = 100 
BATCH_SIZE = 10 
LEARNING_RATE = 0.015 
LEARNING_DECAY_RATE = 0.05 
EPOCHS = 100 
KEEP_PROB = 0.5 
GLOVE = "../data/glove.6B.100d.txt" 
WEIGHT_DECAY = 1e-8

PAD = "__PAD__"
UNK = "__UNK__"
DIM_EMBEDDING = 100 
LSTM_HIDDEN = 100 
BATCH_SIZE = 10 
LEARNING_RATE = 0.015 
LEARNING_DECAY_RATE = 0.05 
EPOCHS = 100 
KEEP_PROB = 0.5 
GLOVE = "../data/glove.6B.100d.txt" 
WEIGHT_DECAY = 1e-8

PAD = "__PAD__"
UNK = "__UNK__"
DIM_EMBEDDING = 100 
LSTM_HIDDEN = 100 
BATCH_SIZE = 10 
LEARNING_RATE = 0.015 
LEARNING_DECAY_RATE = 0.05 
EPOCHS = 100 
KEEP_PROB = 0.5 
GLOVE = "../data/glove.6B.100d.txt" 
# WEIGHT_DECAY = 1e-8 Not used, see note at the bottom of the page
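As noted above, in a real system many of these constants would be command line arguments tuned on the development set. Below is a minimal sketch of that alternative, using the argparse module already imported above. It is not part of the three implementations and the flag names are only illustrative.

parser = argparse.ArgumentParser(description='POS tagger hyperparameters.')
parser.add_argument('--dim-embedding', type=int, default=100)
parser.add_argument('--lstm-hidden', type=int, default=100)
parser.add_argument('--learning-rate', type=float, default=0.015)
parser.add_argument('--keep-prob', type=float, default=0.5)
args = parser.parse_args()
DIM_EMBEDDING = args.dim_embedding
LSTM_HIDDEN = args.lstm_hidden
LEARNING_RATE = args.learning_rate
KEEP_PROB = args.keep_prob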

DyNet library imports. The first allows us to configure DyNet from within code rather than on the command line: mem is the amount of system memory initially allocated (DyNet has its own memory management), autobatch toggles automatic parallelisation of computations, weight_decay rescales weights by (1 - decay) after every update, random_seed sets the seed for random number generation.

import dynet_config
dynet_config.set(mem=256, autobatch=0, weight_decay=WEIGHT_DECAY, random_seed=0)
# dynet_config.set_gpu() for when we want to run with GPUs
import dynet as dy

  
PyTorch library import.

 
import torch
torch.manual_seed(0)

 
Tensorflow library import.

  
import tensorflow as tf

 
# Data reading
def read_data(filename):
# Data reading
def read_data(filename):
# Data reading
def read_data(filename):
We are expecting a minor variation on the raw Penn Treebank data, with one line per sentence, tokens separated by spaces, and the tag for each token placed next to its word (the | works as a separator as it does not appear as a token).

    """Example input:
    Pierre|NNP Vinken|NNP ,|, 61|CD years|NNS old|JJ
    """
    content = []
    with open(filename) as data_src:
        for line in data_src:
            t_p = [w.split("|") for w in line.strip().split()]
            tokens = [v[0] for v in t_p]
            tags = [v[1] for v in t_p]
            content.append((tokens, tags))
    return content

def simplify_token(token):
    chars = []
    for char in token:
    """Example input:
    Pierre|NNP Vinken|NNP ,|, 61|CD years|NNS old|JJ
    """
    content = []
    with open(filename) as data_src:
        for line in data_src:
            t_p = [w.split("|") for w in line.strip().split()]
            tokens = [v[0] for v in t_p]
            tags = [v[1] for v in t_p]
            content.append((tokens, tags))
    return content

def simplify_token(token):
    chars = []
    for char in token:
    """Example input:
    Pierre|NNP Vinken|NNP ,|, 61|CD years|NNS old|JJ
    """
    content = []
    with open(filename) as data_src:
        for line in data_src:
            t_p = [w.split("|") for w in line.strip().split()]
            tokens = [v[0] for v in t_p]
            tags = [v[1] for v in t_p]
            content.append((tokens, tags))
    return content

def simplify_token(token):
    chars = []
    for char in token:
Reduce sparsity by replacing all digits with 0.

        if char.isdigit():
            chars.append("0")
        else:
            chars.append(char)
    return ''.join(chars)

def main():
        if char.isdigit():
            chars.append("0")
        else:
            chars.append(char)
    return ''.join(chars)

def main():
        if char.isdigit():
            chars.append("0")
        else:
            chars.append(char)
    return ''.join(chars)

def main():
For the purpose of this example we only have arguments for locations of the data.

    parser = argparse.ArgumentParser(description='POS tagger.')
    parser.add_argument('training_data')
    parser.add_argument('dev_data')
    args = parser.parse_args()

    train = read_data(args.training_data)
    dev = read_data(args.dev_data)

    parser = argparse.ArgumentParser(description='POS tagger.')
    parser.add_argument('training_data')
    parser.add_argument('dev_data')
    args = parser.parse_args()

    train = read_data(args.training_data)
    dev = read_data(args.dev_data)

    parser = argparse.ArgumentParser(description='POS tagger.')
    parser.add_argument('training_data')
    parser.add_argument('dev_data')
    args = parser.parse_args()

    train = read_data(args.training_data)
    dev = read_data(args.dev_data)
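With only those two positional arguments, each version is run by passing the training and development files on the command line, for example (the script and file names here are purely illustrative):

python tagger.py train.txt dev.txt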

These indices map from strings to integers, which we apply to the input for our model. UNK is added to our mapping so that there is a vector we can use when we encounter unknown words. The special PAD symbol is used in PyTorch and Tensorflow as part of shaping the data in a batch to be a consistent size. It is not needed for DyNet, but kept for consistency.

    # Make indices
    id_to_token = [PAD, UNK]
    token_to_id = {PAD: 0, UNK: 1}
    id_to_tag = [PAD]
    tag_to_id = {PAD: 0}
    # Make indices
    id_to_token = [PAD, UNK]
    token_to_id = {PAD: 0, UNK: 1}
    id_to_tag = [PAD]
    tag_to_id = {PAD: 0}
    # Make indices
    id_to_token = [PAD, UNK]
    token_to_id = {PAD: 0, UNK: 1}
    id_to_tag = [PAD]
    tag_to_id = {PAD: 0}
The '+ dev' may seem like an error, but is done here for convenience. It means in the next section we will retain the GloVe embeddings that appear in dev but not train. They won't be updated during training, so it does not mean we are getting information we shouldn't. In practice I would simply keep all the GloVe embeddings to avoid any potential incorrect use of the evaluation data.

    for tokens, tags in train + dev:
        for token in tokens:
            token = simplify_token(token)
            if token not in token_to_id:
                token_to_id[token] = len(token_to_id)
                id_to_token.append(token)
        for tag in tags:
            if tag not in tag_to_id:
                tag_to_id[tag] = len(tag_to_id)
                id_to_tag.append(tag)
    NWORDS = len(token_to_id)
    NTAGS = len(tag_to_id)

    # Load pre-trained GloVe vectors
    for tokens, tags in train + dev:
        for token in tokens:
            token = simplify_token(token)
            if token not in token_to_id:
                token_to_id[token] = len(token_to_id)
                id_to_token.append(token)
        for tag in tags:
            if tag not in tag_to_id:
                tag_to_id[tag] = len(tag_to_id)
                id_to_tag.append(tag)
    NWORDS = len(token_to_id)
    NTAGS = len(tag_to_id)

    # Load pre-trained GloVe vectors
    for tokens, tags in train + dev:
        for token in tokens:
            token = simplify_token(token)
            if token not in token_to_id:
                token_to_id[token] = len(token_to_id)
                id_to_token.append(token)
        for tag in tags:
            if tag not in tag_to_id:
                tag_to_id[tag] = len(tag_to_id)
                id_to_tag.append(tag)
    NWORDS = len(token_to_id)
    NTAGS = len(tag_to_id)

    # Load pre-trained GloVe vectors
I am assuming these are 100-dimensional GloVe embeddings in their standard format.

    pretrained = {}
    for line in open(GLOVE):
        parts = line.strip().split()
        word = parts[0]
        vector = [float(v) for v in parts[1:]]
        pretrained[word] = vector
    pretrained = {}
    for line in open(GLOVE):
        parts = line.strip().split()
        word = parts[0]
        vector = [float(v) for v in parts[1:]]
        pretrained[word] = vector
    pretrained = {}
    for line in open(GLOVE):
        parts = line.strip().split()
        word = parts[0]
        vector = [float(v) for v in parts[1:]]
        pretrained[word] = vector
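Each line of the GloVe file is a word followed by its 100 values, separated by spaces, which is the format the loop above assumes. An illustrative line (the numbers here are made up, and there are 100 of them in total):

the 0.12 -0.31 0.05 ... 0.27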
We need the word vectors as a list to initialise the embeddings. Each entry in the list corresponds to the token with that index.

    pretrained_list = []
    scale = np.sqrt(3.0 / DIM_EMBEDDING)
    for word in id_to_token:
        # apply lower() because all GloVe vectors are for lowercase words
        if word.lower() in pretrained:
            pretrained_list.append(np.array(pretrained[word.lower()]))
        else:
    pretrained_list = []
    scale = np.sqrt(3.0 / DIM_EMBEDDING)
    for word in id_to_token:
        # apply lower() because all GloVe vectors are for lowercase words
        if word.lower() in pretrained:
            pretrained_list.append(np.array(pretrained[word.lower()]))
        else:
    pretrained_list = []
    scale = np.sqrt(3.0 / DIM_EMBEDDING)
    for word in id_to_token:
        # apply lower() because all GloVe vectors are for lowercase words
        if word.lower() in pretrained:
            pretrained_list.append(np.array(pretrained[word.lower()]))
        else:
For words that do not appear in GloVe we generate a random vector (note: the choice of scale here is important, and we follow Yang, Liang and Zhang (CoLing 2018)).

            random_vector = np.random.uniform(-scale, scale, [DIM_EMBEDDING])
            pretrained_list.append(random_vector)

            random_vector = np.random.uniform(-scale, scale, [DIM_EMBEDDING])
            pretrained_list.append(random_vector)

            random_vector = np.random.uniform(-scale, scale, [DIM_EMBEDDING])
            pretrained_list.append(random_vector)
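As an aside (not part of the tagger), one way to see what this scale does: a uniform distribution on [-a, a] has variance a^2 / 3, so a = sqrt(3 / DIM_EMBEDDING) gives the random vectors variance 1 / DIM_EMBEDDING, which is easy to check:

scale = np.sqrt(3.0 / DIM_EMBEDDING)
sample = np.random.uniform(-scale, scale, 100000)
print(np.var(sample), 1.0 / DIM_EMBEDDING)  # both approximately 0.01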

The most significant difference between the frameworks is how the model parameters and their execution are defined. In DyNet we define parameters here and then define computation as needed. In PyTorch we use a class with the parameters defined in the constructor and the computation defined in the forward() method. In Tensorflow we define both parameters and computation here.

    # Model creation
    # Model creation
    # Model creation
 
    model = dy.ParameterCollection()
    # Create word embeddings and initialise
  
Lookup parameters are a matrix that supports efficient sparse lookup.

    pEmbedding = model.add_lookup_parameters((NWORDS, DIM_EMBEDDING))
    pEmbedding.init_from_array(np.array(pretrained_list))
    # Create LSTM parameters
  
Objects that create LSTM cells and the necessary parameters.

    stdv = 1.0 / np.sqrt(LSTM_HIDDEN) 
    f_lstm = dy.VanillaLSTMBuilder(1, DIM_EMBEDDING, LSTM_HIDDEN, model,
            forget_bias=(np.random.random_sample() - 0.5) * 2 * stdv)
    b_lstm = dy.VanillaLSTMBuilder(1, DIM_EMBEDDING, LSTM_HIDDEN, model,
            forget_bias=(np.random.random_sample() - 0.5) * 2 * stdv)
    # Create output layer
    pOutput = model.add_parameters((NTAGS, 2 * LSTM_HIDDEN))
    
    # Set recurrent dropout values (not used in this case)
    f_lstm.set_dropouts(0.0, 0.0)
    b_lstm.set_dropouts(0.0, 0.0)
    # Initialise LSTM parameters
  
To match PyTorch, we initialise the parameters with an unconventional approach.

    f_lstm.get_parameters()[0][0].set_value(
            np.random.uniform(-stdv, stdv, [4 * LSTM_HIDDEN, DIM_EMBEDDING]))
    f_lstm.get_parameters()[0][1].set_value(
            np.random.uniform(-stdv, stdv, [4 * LSTM_HIDDEN, LSTM_HIDDEN]))
    f_lstm.get_parameters()[0][2].set_value(
            np.random.uniform(-stdv, stdv, [4 * LSTM_HIDDEN]))
    b_lstm.get_parameters()[0][0].set_value(
            np.random.uniform(-stdv, stdv, [4 * LSTM_HIDDEN, DIM_EMBEDDING]))
    b_lstm.get_parameters()[0][1].set_value(
            np.random.uniform(-stdv, stdv, [4 * LSTM_HIDDEN, LSTM_HIDDEN]))
    b_lstm.get_parameters()[0][2].set_value(
            np.random.uniform(-stdv, stdv, [4 * LSTM_HIDDEN]))

  
The trainer object is used to update the model.

    # Create the trainer
    trainer = dy.SimpleSGDTrainer(model, learning_rate=LEARNING_RATE)
  
DyNet clips gradients by default, which we disable here (this can have a big impact on performance).

    trainer.set_clip_threshold(-1)

  
  
    model = TaggerModel(NWORDS, NTAGS, pretrained_list, id_to_token)
    # Create optimizer and configure the learning rate
    optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE,
            weight_decay=WEIGHT_DECAY)
 
The learning rate for each epoch is set by multiplying the initial rate by the factor produced by this function.

 
    rescale_lr = lambda epoch: 1 / (1 + LEARNING_DECAY_RATE * epoch)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
            lr_lambda=rescale_lr)
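As an aside (not part of the tagger), the factor for epoch 0 is 1, so training starts at the initial rate of 0.015, and the first few effective rates are easy to check:

for epoch in range(3):
    print(LEARNING_RATE / (1 + LEARNING_DECAY_RATE * epoch))
# prints 0.015, then approximately 0.014286 and 0.013636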

 

This line creates a new graph and makes it the default graph for operations to be registered to. It is not necessary here because we only have one graph, but is considered good practice (there is more discussion of this on Stack Overflow).

  
    with tf.Graph().as_default():
Placeholders are inputs/values that will be fed into the network each time it is run. We define their type, name, and shape (constant, 1D vector, 2D vector, etc). This includes what we normally think of as inputs (e.g. the tokens) as well as parameters we want to change at run time (e.g. the learning rate).

  
        # Define inputs
        e_input = tf.placeholder(tf.int32, [None, None], name='input')
        e_lengths = tf.placeholder(tf.int32, [None], name='lengths')
        e_mask = tf.placeholder(tf.int32, [None, None], name='mask')
        e_gold_output = tf.placeholder(tf.int32, [None, None],
                name='gold_output')
        e_keep_prob = tf.placeholder(tf.float32, name='keep_prob')
        e_learning_rate = tf.placeholder(tf.float32, name='learning_rate')

        # Define word embedding
The embedding matrix is a variable (so it can be updated during training), initialised with the vectors defined above.

  
        glove_init = tf.constant_initializer(np.array(pretrained_list))
        e_embedding = tf.get_variable("embedding", [NWORDS, DIM_EMBEDDING],
                initializer=glove_init)
        e_embed = tf.nn.embedding_lookup(e_embedding, e_input)

        # Define LSTM cells
We create an LSTM cell, then wrap it in a class that applies dropout to the input and output.

  
        e_cell_f = tf.contrib.rnn.BasicLSTMCell(LSTM_HIDDEN)
        e_cell_f = tf.contrib.rnn.DropoutWrapper(e_cell_f,
                input_keep_prob=e_keep_prob, output_keep_prob=e_keep_prob)
        # Recurrent dropout options
We are not using recurrent dropout, but it is a common enough feature of networks that it's good to see how it is done.

  
        # e_cell_f = tf.contrib.rnn.DropoutWrapper(e_cell_f,
        #        input_keep_prob=e_keep_prob, output_keep_prob=e_keep_prob,
        #        state_keep_prob=e_keep_prob, variational_recurrent=True,
        #        dtype=tf.float32, input_size=DIM_EMBEDDING)
Similarly, multi-layer networks are a common use case. In Tensorflow, we would wrap a list of cells with a MultiRNNCell.

  
        # Multi-layer cell creation
        # e_cell_f = tf.contrib.rnn.MultiRNNCell([e_cell_f])
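For example, a two-layer forward cell would stack separate cells (a sketch, not used in this model):

        # e_cells_f = [tf.contrib.rnn.BasicLSTMCell(LSTM_HIDDEN) for _ in range(2)]
        # e_cell_f = tf.contrib.rnn.MultiRNNCell(e_cells_f)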
We are making a bidirectional network, so we need another cell for the reverse direction.

  
        e_cell_b = tf.contrib.rnn.BasicLSTMCell(LSTM_HIDDEN)
        e_cell_b = tf.contrib.rnn.DropoutWrapper(e_cell_b,
                input_keep_prob=e_keep_prob, output_keep_prob=e_keep_prob)
To use the cells we create a dynamic RNN. The 'dynamic' aspect means we can feed in the lengths of input sequences not counting padding and it will stop early.

  
        e_initial_state_f = e_cell_f.zero_state(BATCH_SIZE, dtype=tf.float32)
        e_initial_state_b = e_cell_b.zero_state(BATCH_SIZE, dtype=tf.float32)
        e_lstm_outputs, e_final_state = tf.nn.bidirectional_dynamic_rnn(
                cell_fw=e_cell_f, cell_bw=e_cell_b, inputs=e_embed,
                initial_state_fw=e_initial_state_f,
                initial_state_bw=e_initial_state_b,
                sequence_length=e_lengths, dtype=tf.float32)
        e_lstm_outputs_merged = tf.concat(e_lstm_outputs, 2)

        # Define output layer
Matrix multiply to get scores for each class.

  
        e_predictions = tf.contrib.layers.fully_connected(e_lstm_outputs_merged,
                NTAGS, activation_fn=None)
        # Define loss and update
Cross-entropy loss. The reduction flag is crucial (the default is to average over the sequence). The weights flag accounts for padding that makes all of the sequences the same length.

  
        e_loss = tf.losses.sparse_softmax_cross_entropy(e_gold_output,
                e_predictions, weights=e_mask,
                reduction=tf.losses.Reduction.SUM)
        e_train = tf.train.GradientDescentOptimizer(e_learning_rate).minimize(e_loss)
        # Update with gradient clipping
If we wanted to do gradient clipping we would need to do the update in a few steps, first calculating the gradient, then modifying it before applying it.

  
        # e_optimiser = tf.train.GradientDescentOptimizer(LEARNING_RATE)
        # e_gradients = e_optimiser.compute_gradients(e_loss)
        # e_clipped_gradients = [(tf.clip_by_value(grad, -5., 5.), var)
        #         for grad, var in e_gradients]
        # e_train = e_optimiser.apply_gradients(e_clipped_gradients)

        # Define output
        e_auto_output = tf.argmax(e_predictions, 2, output_type=tf.int32)

        # Do training
Configure the system environment. By default Tensorflow uses all available GPUs and RAM. These lines limit the number of GPUs used and the amount of RAM. To limit which GPUs are used, set the environment variable CUDA_VISIBLE_DEVICES (e.g. "export CUDA_VISIBLE_DEVICES=0,1").

  
        config = tf.ConfigProto(
            device_count = {'GPU': 0},
            gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction = 0.8)
        )
A session runs the graph. We use a 'with' block to ensure it is closed, which frees various resources.

  
        with tf.Session(config=config) as sess:
Run executes operations, in this case initializing the variables.

  
            sess.run(tf.global_variables_initializer())

To make the code match across the three versions, we group together some framework-specific values needed when doing a pass over the data.

    expressions = (pEmbedding, pOutput, f_lstm, b_lstm, trainer)
    expressions = (model, optimizer)
            expressions = [
                e_auto_output, e_gold_output, e_input, e_keep_prob, e_lengths,
                e_loss, e_train, e_mask, e_learning_rate, sess
            ]
Main training loop, in which we shuffle the data, set the learning rate, do one complete pass over the training data, then evaluate on the development data.

    for epoch in range(EPOCHS):
        random.shuffle(train)

    for epoch in range(EPOCHS):
        random.shuffle(train)

            for epoch in range(EPOCHS):
                random.shuffle(train)

 
        # Update learning rate
        trainer.learning_rate = LEARNING_RATE / (1+ LEARNING_DECAY_RATE * epoch)

  
  
        # Update learning rate
 
The first call to rescale_lr is with 0, which is why this must be done before the pass over the data.

 
        scheduler.step()

 
Training mode (and evaluation mode below) do things like enable dropout components.

 
        model.train() 
        model.zero_grad()
 
   
                # Determine the current learning rate
                current_lr = LEARNING_RATE / (1+ LEARNING_DECAY_RATE * epoch)

Training pass.

        loss, tacc = do_pass(train, token_to_id, tag_to_id, expressions, True)
        loss, tacc = do_pass(train, token_to_id, tag_to_id, expressions,
                True)

                loss, tacc = do_pass(train, token_to_id, tag_to_id, expressions,
                        True, current_lr)
  
        model.eval()
 
Dev pass.

        _, dacc = do_pass(dev, token_to_id, tag_to_id, expressions, False)
        print("{} loss {} t-acc {} d-acc {}".format(epoch, loss, tacc, dacc))

        _, dacc = do_pass(dev, token_to_id, tag_to_id, expressions, False)
        print("{} loss {} t-acc {} d-acc {}".format(epoch, loss,
            tacc, dacc))

                _, dacc = do_pass(dev, token_to_id, tag_to_id, expressions,
                        False)
                print("{} loss {} t-acc {} d-acc {}".format(epoch, loss, tacc,
                    dacc))

The syntax varies, but in all three cases either saving or loading the parameters of a model must be done after the model is defined.

    # Save model
    model.save("tagger.dy.model")

    # Load model
    model.populate("tagger.dy.model")

    # Evaluation pass.
    _, test_acc = do_pass(dev, token_to_id, tag_to_id, expressions, False)
    print("Test Accuracy: {:.3f}".format(test_acc))

    # Save model
    torch.save(model.state_dict(), "tagger.pt.model")

    # Load model
    model.load_state_dict(torch.load('tagger.pt.model'))

    # Evaluation pass.
    _, test_acc = do_pass(dev, token_to_id, tag_to_id, expressions, False)
    print("Test Accuracy: {:.3f}".format(test_acc))

            # Save model
            saver = tf.train.Saver()
            saver.save(sess, "./tagger.tf.model")

            # Load model
            saver.restore(sess, "./tagger.tf.model")

            # Evaluation pass.
            _, test_acc = do_pass(dev, token_to_id, tag_to_id, expressions,
                    False)
            print("Test Accuracy: {:.3f}".format(test_acc))

Neural network definition code. In PyTorch networks are defined using classes that extend Module.

 
class TaggerModel(torch.nn.Module):
 
In the constructor we define objects that will do each of the computations.

 
    def __init__(self, nwords, ntags, pretrained_list, id_to_token):
        super().__init__()

        # Create word embeddings
        pretrained_tensor = torch.FloatTensor(pretrained_list)
        self.word_embedding = torch.nn.Embedding.from_pretrained(
                pretrained_tensor, freeze=False)
        # Create input dropout parameter
        self.word_dropout = torch.nn.Dropout(1 - KEEP_PROB)
        # Create LSTM parameters
        self.lstm = torch.nn.LSTM(DIM_EMBEDDING, LSTM_HIDDEN, num_layers=1,
                batch_first=True, bidirectional=True)
        # Create output dropout parameter
        self.lstm_output_dropout = torch.nn.Dropout(1 - KEEP_PROB)
        # Create final matrix multiply parameters
        self.hidden_to_tag = torch.nn.Linear(LSTM_HIDDEN * 2, ntags)

    def forward(self, sentences, labels, lengths, cur_batch_size):
        max_length = sentences.size(1)

        # Look up word vectors
        word_vectors = self.word_embedding(sentences)
        # Apply dropout
        dropped_word_vectors = self.word_dropout(word_vectors)
        # Run the LSTM over the input, reshaping data for efficiency
 
Assuming the data is ordered longest to shortest, this provides a view of the data that fits with how cuDNN works.

 
        packed_words = torch.nn.utils.rnn.pack_padded_sequence(
                dropped_word_vectors, lengths, True)
 
The None argument is an optional initial hidden state (default is a zero vector). The ignored return value contains the hidden states.

 
        lstm_out, _ = self.lstm(packed_words, None)
 
Reverse the view shift made for cuDNN. Specifying total_length is not necessary in general (it can be inferred), but is necessary for parallel processing. The ignored return value contains the length of each sequence.

 
        lstm_out, _ = torch.nn.utils.rnn.pad_packed_sequence(lstm_out,
                batch_first=True, total_length=max_length)
        # Apply dropout
        lstm_out_dropped = self.lstm_output_dropout(lstm_out)
        # Matrix multiply to get scores for each tag
        output_scores = self.hidden_to_tag(lstm_out_dropped)

        # Calculate loss and predictions
 
We reshape to [batch size * sequence length, ntags] for more efficient processing.

 
        output_scores = output_scores.view(cur_batch_size * max_length, -1)
        flat_labels = labels.view(cur_batch_size * max_length)
 
The ignore index tells the loss which outputs not to score, which we use to ignore padding. 'reduction' defines how the losses at each point in the sequence are combined. The default is elementwise_mean, which would not do what we want.

 
        loss_function = torch.nn.CrossEntropyLoss(ignore_index=0, reduction='sum')
        loss = loss_function(output_scores, flat_labels)
        predicted_tags  = torch.argmax(output_scores, 1)
 
Reshape to have dimensions [batch size, sequence length].

 
        predicted_tags = predicted_tags.view(cur_batch_size, max_length)
        return loss, predicted_tags

 
Inference (the same function for train and test).

def do_pass(data, token_to_id, tag_to_id, expressions, train):
    pEmbedding, pOutput, f_lstm, b_lstm, trainer = expressions


    # Loop over batches
    loss = 0
    match = 0
    total = 0
    for start in range(0, len(data), BATCH_SIZE):
def do_pass(data, token_to_id, tag_to_id, expressions, train):
    model, optimizer = expressions


    # Loop over batches
    loss = 0
    match = 0
    total = 0
    for start in range(0, len(data), BATCH_SIZE):
def do_pass(data, token_to_id, tag_to_id, expressions, train, lr=0.0):
    e_auto_output, e_gold_output, e_input, e_keep_prob, e_lengths, e_loss, \
            e_train, e_mask, e_learning_rate, session = expressions

    # Loop over batches
    loss = 0
    match = 0
    total = 0
    for start in range(0, len(data), BATCH_SIZE):
Form the batch and order it based on length (important for efficient processing in PyTorch).

        batch = data[start : start + BATCH_SIZE]
        batch.sort(key = lambda x: -len(x[0]))
        batch = data[start : start + BATCH_SIZE]
        batch.sort(key = lambda x: -len(x[0]))
        batch = data[start : start + BATCH_SIZE]
        batch.sort(key = lambda x: -len(x[0]))
Log partial results so we can conveniently check progress.

        if start % 4000 == 0 and start > 0:
            print(loss, match / total)
            sys.stdout.flush()

        if start % 4000 == 0 and start > 0:
            print(loss, match / total)
            sys.stdout.flush()

        if start % 4000 == 0 and start > 0:
            print(loss, match / total)
            sys.stdout.flush()

Start a new computation graph for this batch.

        # Process batch
        dy.renew_cg()
  
For each example, we will construct an expression that gives the loss.

        loss_expressions = []
        predicted = []
  
  
        # Prepare inputs
 
Prepare input arrays, using .long() to cast the type from Tensor to LongTensor.

 
        cur_batch_size = len(batch)
        max_length = len(batch[0][0])
        lengths = [len(v[0]) for v in batch]
        input_array = torch.zeros((cur_batch_size, max_length)).long()
        output_array = torch.zeros((cur_batch_size, max_length)).long()
 
   
        # Add empty sentences to fill the batch
We add empty sentences because Tensorflow requires every batch to be the same size.

  
        batch += [([], []) for _ in range(BATCH_SIZE - len(batch))]
        # Prepare inputs
We do this here for convenience and to have greater alignment between implementations, but in practice it would be best to do this once in pre-processing.

  
        max_length = len(batch[0][0])
        input_array = np.zeros([len(batch), max_length])
        output_array = np.zeros([len(batch), max_length])
        lengths = np.array([len(v[0]) for v in batch])
        mask = np.zeros([len(batch), max_length])
Convert tokens and tags from strings to numbers using the indices.

        for n, (tokens, tags) in enumerate(batch):
            token_ids = [token_to_id.get(simplify_token(t), 0) for t in tokens]
            tag_ids = [tag_to_id[t] for t in tags]

        for n, (tokens, tags) in enumerate(batch):
            token_ids = [token_to_id.get(simplify_token(t), 0) for t in tokens]
            tag_ids = [tag_to_id[t] for t in tags]

        for n, (tokens, tags) in enumerate(batch):
            token_ids = [token_to_id.get(simplify_token(t), 0) for t in tokens]
            tag_ids = [tag_to_id[t] for t in tags]

Now we define the computation to be performed with the model. Note that it is not executed yet; we are simply building the computation graph.

            # Look up word embeddings
            wembs = [dy.lookup(pEmbedding, w) for w in token_ids]
            # Apply dropout
            if train:
                wembs = [dy.dropout(w, 1.0 - KEEP_PROB) for w in wembs]
            # Feed words into the LSTM
  
Create expressions for the two LSTMs and feed in the embeddings (reversed in one case).
We pull out the output vector from the cell state at each step.

            f_init = f_lstm.initial_state()
            f_lstm_output = [x.output() for x in f_init.add_inputs(wembs)]
            rev_embs = reversed(wembs)
            b_init = b_lstm.initial_state()
            b_lstm_output = [x.output() for x in b_init.add_inputs(rev_embs)]

            # For each step, calculate the tag scores and loss
            pred_tags = []
            for f, b, t in zip(f_lstm_output, reversed(b_lstm_output), tag_ids):
                # Combine the outputs
                combined = dy.concatenate([f,b])
                # Apply dropout
                if train:
                    combined = dy.dropout(combined, 1.0 - KEEP_PROB)
                # Matrix multiply to get scores for each tag
                r_t = pOutput * combined
                # Calculate cross-entropy loss
                if train:
                    err = dy.pickneglogsoftmax(r_t, t)
  
We are not actually evaluating the loss values here; instead we collect them in a list. This enables DyNet's autobatching.

                    loss_expressions.append(err)
                # Calculate the highest scoring tag
  
This call to .npvalue() will lead to evaluation of the graph and so we don't actually get the benefits of autobatching. With some refactoring we could get the benefit back (simply keep the r_t expressions around and do this after the update), but that would have complicated this code.

                chosen = np.argmax(r_t.npvalue())
                pred_tags.append(chosen)
            predicted.append(pred_tags)

        # combine the losses for the batch, do an update, and record the loss
        if train:
            loss_for_batch = dy.esum(loss_expressions)
            loss_for_batch.backward()
            trainer.update()
            loss += loss_for_batch.scalar_value()

  
Fill the arrays, leaving the remaining values as zero (our padding value).

 
            input_array[n, :len(tokens)] = torch.LongTensor(token_ids)
            output_array[n, :len(tags)] = torch.LongTensor(tag_ids)

        # Construct computation
 
Calling the model as a function will run its forward() function, which constructs the computations.

 
        batch_loss, output = model(input_array, output_array, lengths,
                cur_batch_size)

        # Run computations
        if train:
            batch_loss.backward()
            optimizer.step()
            model.zero_grad()
 
To get the loss value we use .item().

 
            loss += batch_loss.item()
 
Our output is an array (rather than a single value), so we use a different approach to get it into a usable form.

 
        predicted = output.cpu().data.numpy()

 
Fill the arrays, leaving the remaining values as zero (our padding value).

  
            input_array[n, :len(tokens)] = token_ids
            output_array[n, :len(tags)] = tag_ids
            mask[n, :len(tokens)] = np.ones([len(tokens)])
We can't change the computation graph to disable dropout when not training, so we just change the keep probability.

  
        cur_keep_prob = KEEP_PROB if train else 1.0
This dictionary contains values for all of the placeholders we defined.

  
        feed = {
                e_input: input_array,
                e_gold_output: output_array,
                e_mask: mask,
                e_keep_prob: cur_keep_prob,
                e_lengths: lengths,
                e_learning_rate: lr
        }

        # Define the computations needed
        todo = [e_auto_output]
If we are not training we do not need to compute a loss and we do not want to do the update.

  
        if train:
            todo.append(e_loss)
            todo.append(e_train)
        # Run computations
        outcomes = session.run(todo, feed_dict=feed)
        # Get outputs
        predicted = outcomes[0]
        if train:
We do not request the e_train value because its work is already done: it performed the update as part of the computation.

  
            loss += outcomes[1]

 
        # Update the number of correct tags and total tags
        for (_, g), a in zip(batch, predicted):
            total += len(g)
            for gt, at in zip(g, a):
                gt = tag_to_id[gt]
                if gt == at:
                    match += 1

    return loss, match / total

if __name__ == '__main__':
    main()
        # Update the number of correct tags and total tags
        for (_, g), a in zip(batch, predicted):
            total += len(g)
            for gt, at in zip(g, a):
                gt = tag_to_id[gt]
                if gt == at:
                    match += 1

    return loss, match / total

if __name__ == '__main__':
    main()
        # Update the number of correct tags and total tags
        for (_, g), a in zip(batch, predicted):
            total += len(g)
            for gt, at in zip(g, a):
                gt = tag_to_id[gt]
                if gt == at:
                    match += 1

    return loss, match / total

if __name__ == '__main__':
    main()

This code was last updated in August 2018. If one of the frameworks has changed in a way that should be reflected here, please let me know!

A few miscellaneous notes:

And a few other gotchas I've come across:

I developed this code and webpage with help from many people and resources. In particular: