ValueError when trying to run OpenAI's GPT-2 text generator

Hello there,

I’m trying to replicate the GPT-2 NLP text generator model that I read about here: https://www.analyticsvidhya.com/blog/2019/07/openai-gpt2-text-generator-python/

I’ve set up the environment, but when I run the model with the parameters shown below, I get the following `ValueError`:

```
ValueError: Dimensions must be equal, but are 50257 and 100 for 'sample_sequence/Less' (op: 'Less') with input shapes: [100,50257], [100].
```

I’m using the same `interact_model` code from the article (pasted below) with the following call:

```python
interact_model('345M', 12345, 100, 100, 1000, 10, 40, )
```
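To make the positional call easier to read, here is a small sketch that binds the same seven values to a hypothetical stub with the parameter list of the `interact_model` function posted below, using the standard library's `inspect.signature(...).bind_partial`. It shows that `nsamples` and `batch_size` are both 100. Since GPT-2's vocabulary has 50257 tokens, the `[100, 50257]` operand in the error is presumably a logits tensor of shape `[batch_size, n_vocab]`. Note that `models_dir` is not among the values passed.

```python
import inspect

# Hypothetical stub: same parameter names and order as the interact_model
# function in this question. It does nothing; it only exists so the call
# can be bound to names.
def interact_model_stub(model_name, seed, nsamples, batch_size,
                        length, temperature, top_k, models_dir):
    pass

# bind_partial maps positional values to parameter names without calling
# the function and without failing because models_dir is missing.
bound = inspect.signature(interact_model_stub).bind_partial(
    '345M', 12345, 100, 100, 1000, 10, 40)

for name, value in bound.arguments.items():
    print(f"{name} = {value!r}")
```

Running it prints one `name = value` line per bound parameter, from `model_name = '345M'` through `top_k = 40`.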

Does anyone know why this error is being thrown?

It looks like the graph performs an element-wise “less than” comparison (the `Less` op inside `sample_sequence`) between a tensor of shape `[100, 50257]` and one of shape `[100]`, and because those shapes are incompatible, the op cannot be built.
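The shape problem can be checked by hand. Below is a small pure-Python sketch of the broadcasting rule that NumPy uses and that TensorFlow's element-wise ops such as `Less` also follow: dimensions are compared from the right, and each pair must be equal or contain a 1. The helper name `broadcast_shape` is hypothetical, and its error text only imitates TensorFlow's wording.

```python
from itertools import zip_longest


def broadcast_shape(a, b):
    """Return the broadcast shape of shapes a and b, or raise ValueError.

    Dimensions are compared right to left; each pair must be equal, or one
    of them must be 1. Missing leading dimensions are treated as 1.
    """
    result = []
    for x, y in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if x == y or y == 1:
            result.append(x)
        elif x == 1:
            result.append(y)
        else:
            raise ValueError(f"Dimensions must be equal, but are {x} and {y}")
    return tuple(reversed(result))


# Shapes from the error: the trailing 50257 is lined up against 100.
try:
    broadcast_shape((100, 50257), (100,))
except ValueError as err:
    print(err)

# With a batch of 1, the same kind of comparison happens to broadcast.
print(broadcast_shape((1, 50257), (1,)))

# A trailing axis of size 1 broadcasts for any batch size.
print(broadcast_shape((100, 50257), (100, 1)))
```

Under this rule a `[100]` operand is aligned with the vocabulary axis rather than the batch axis, which matches the error. The same comparison would broadcast with `batch_size=1`, and a `[100, 1]` operand would broadcast for any batch size.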

Thanks in advance!

```python
import json
import os

import numpy as np
import tensorflow as tf  # TF 1.x API (tf.Session, tf.placeholder)

# model.py, sample.py and encoder.py from the src/ directory of the openai/gpt-2 repo
import model, sample, encoder


def interact_model(
    model_name,
    seed,
    nsamples,
    batch_size,
    length,
    temperature,
    top_k,
    models_dir
):
    models_dir = os.path.expanduser(os.path.expandvars(models_dir))
    if batch_size is None:
        batch_size = 1
    # Each sess.run call below produces batch_size samples.
    assert nsamples % batch_size == 0

    enc = encoder.get_encoder(model_name, models_dir)
    hparams = model.default_hparams()
    with open(os.path.join(models_dir, model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if length is None:
        length = hparams.n_ctx // 2
    elif length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx)

    with tf.Session(graph=tf.Graph()) as sess:
        # One row of prompt tokens per sample in the batch.
        context = tf.placeholder(tf.int32, [batch_size, None])
        np.random.seed(seed)
        tf.set_random_seed(seed)
        output = sample.sample_sequence(
            hparams=hparams, length=length,
            context=context,
            batch_size=batch_size,
            temperature=temperature, top_k=top_k
        )

        # Restore the downloaded checkpoint for the chosen model.
        saver = tf.train.Saver()
        ckpt = tf.train.latest_checkpoint(os.path.join(models_dir, model_name))
        saver.restore(sess, ckpt)

        while True:
            raw_text = input("Model prompt >>> ")
            while not raw_text:
                print('Prompt should not be empty!')
                raw_text = input("Model prompt >>> ")
            context_tokens = enc.encode(raw_text)
            generated = 0
            for _ in range(nsamples // batch_size):
                # Feed the same prompt to every row, then drop the prompt tokens.
                out = sess.run(output, feed_dict={
                    context: [context_tokens for _ in range(batch_size)]
                })[:, len(context_tokens):]
                for i in range(batch_size):
                    generated += 1
                    text = enc.decode(out[i])
                    print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
                    print(text)
            print("=" * 80)
```
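The `Less` op does not appear in the `interact_model` code itself, so it presumably comes from inside `sample.sample_sequence` (the repo's `sample.py`, not shown here), for example from a top-k filtering step that compares each logit with the k-th largest value in its row. The NumPy sketch below illustrates that kind of filtering under this assumption; the function name `top_k_mask` is hypothetical and this is not the repo's code. The detail that matters is the shape of the per-row threshold.

```python
import numpy as np


def top_k_mask(logits, k):
    """Keep the k largest logits in each row and push the rest to -1e10.

    logits: array of shape [batch, vocab].
    """
    # k-th largest value in each row, shape [batch].
    kth_largest = np.sort(logits, axis=-1)[:, -k]
    # Add a trailing axis -> shape [batch, 1], so the comparison broadcasts
    # against [batch, vocab]. Comparing with the [batch] version directly
    # fails unless batch == 1 or batch == vocab.
    min_values = kth_largest[:, np.newaxis]
    return np.where(logits < min_values, -1e10, logits)


logits = np.array([[1.0, 5.0, 3.0, 2.0, 4.0],
                   [9.0, 0.0, 8.0, 7.0, 6.0],
                   [0.5, 0.1, 0.9, 0.3, 0.7]])
print(top_k_mask(logits, k=2))
```

With `k=2`, each row keeps its two largest logits (5.0 and 4.0 in the first row) and every other entry becomes -1e10.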
© Copyright 2013-2019 Analytics Vidhya