GPT-2

# coding=utf-8
#
# Copyright 2021 Biderman et al. This file is based on code by the authors denoted below and has been modified from its original version.
#
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""GPT-2 model."""
"""GPT-2 model."""


import torch

from megatron import get_args
from megatron import mpu
from megatron.module import MegatronModule

from functools import partial
from .language_model import parallel_lm_logits
from .language_model import get_language_model
from .utils import init_method_normal
from .utils import scaled_init_method_normal
from .norms import LayerNorm, RMSNorm, ScaleNorm

# Pipeline parallelism
from megatron import mpu
from megatron.mpu import ParallelRelativePositionBias
import torch.nn.functional as F
import megatron.fp16 as fp16
from megatron.model.transformer import ParallelTransformerLayerPipe
from .language_model import EmbeddingPipe

from deepspeed.pipe import PipelineModule, LayerSpec, TiedLayerSpec


def gpt2_attention_mask_func(attention_scores, ltor_mask):
    attention_scores.masked_fill_(ltor_mask, -10000.0)
    return attention_scores
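# Mask convention (explanatory note, not in the upstream file): `ltor_mask` is
# a boolean tensor that is True at positions a query must NOT attend to, so
# filling those scores with -10000.0 drives their post-softmax weight to ~0
# while remaining representable in fp16 (unlike -inf, which can produce NaNs).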




def cross_entropy(output, labels, _fp16=False):
    """ From pretrain_gpt2:forward_step() """
    labels, loss_mask = labels[0], labels[1]
    if _fp16:
        # Keep the logits in half precision; the model-parallel cross-entropy
        # is applied to the fp16 tensor directly.
        assert output.dtype == torch.half
        losses = mpu.vocab_parallel_cross_entropy(output.contiguous(), labels)
    else:
        output = fp16.fp16_to_fp32(output)
        losses = mpu.vocab_parallel_cross_entropy(output.contiguous(), labels)
    loss_mask = loss_mask.view(-1)
    loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
    return loss
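# Usage sketch (illustrative, assuming the (labels, loss_mask) pairing produced
# by the GPT-2 pretraining data pipeline):
#
#     loss_fn = partial(cross_entropy, _fp16=args.fp16_lm_cross_entropy)
#     loss = loss_fn(logits, (labels, loss_mask))
#
# GPT2ModelPipe below binds _fp16 the same way before handing loss_fn to
# DeepSpeed's PipelineModule.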





class GPT2Model(MegatronModule):
    """GPT-2 Language model."""

    def __init__(self, num_tokentypes=0, parallel_output=True):
        super(GPT2Model, self).__init__()
        args = get_args()
        self.weight_tying = not args.no_weight_tying

        self.parallel_output = parallel_output
        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy

        self.language_model, self._language_model_key = get_language_model(
            attention_mask_func=gpt2_attention_mask_func,
            num_tokentypes=num_tokentypes,
            add_pooler=False,
            init_method=init_method_normal(args.init_method_std),
            scaled_init_method=scaled_init_method_normal(args.init_method_std,
                                                         args.num_layers))

    def forward(self, input_ids, position_ids, attention_mask, labels=None,
                tokentype_ids=None, layer_past=None, get_key_value=False,
                forward_method_parallel_output=None):

        # Language model.
        lm_output = self.language_model(input_ids,
                                        position_ids,
                                        attention_mask,
                                        tokentype_ids=tokentype_ids,
                                        layer_past=layer_past,
                                        get_key_value=get_key_value)

        if get_key_value:
            lm_output, presents = lm_output

        # Output.
        parallel_output = self.parallel_output
        if forward_method_parallel_output is not None:
            parallel_output = forward_method_parallel_output
        if self.weight_tying:
            # Project hidden states onto the vocabulary using the tied word
            # embedding matrix.
            output = parallel_lm_logits(
                lm_output,
                self.language_model.embedding.word_embeddings.weight,
                parallel_output)
        else:
            # Untied output head: no embedding weight is passed in.
            output = parallel_lm_logits(
                lm_output,
                None,
                parallel_output, weight_tying=False)

        if get_key_value:
            output = [output, presents]

        if labels is None:
            return output
        else:
            if self.fp16_lm_cross_entropy:
                assert output.dtype == torch.half
                loss = mpu.vocab_parallel_cross_entropy(output, labels)
            else:
                loss = mpu.vocab_parallel_cross_entropy(output.float(), labels)
            return loss

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):

        state_dict_ = {}
        state_dict_[self._language_model_key] \
            = self.language_model.state_dict_for_save_checkpoint(
                destination, prefix, keep_vars)
        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""

        if self._language_model_key in state_dict:
            state_dict = state_dict[self._language_model_key]
        self.language_model.load_state_dict(state_dict, strict=strict)
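# Usage sketch for the non-pipeline model (illustrative; assumes megatron's
# global args and the mpu model-parallel groups are already initialized, as in
# the pretraining entry point):
#
#     model = GPT2Model(num_tokentypes=0, parallel_output=True)
#     logits = model(tokens, position_ids, attention_mask)
#     loss = model(tokens, position_ids, attention_mask, labels=labels)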




class GPT2ModelPipe(PipelineModule, MegatronModule):
    """GPT2Model adapted for pipeline parallelism.

    The largest change is flattening the GPTModel class so we can express it as a
    sequence of layers including embedding, transformer layers, and output.
    """

    def __init__(self, num_tokentypes=0, parallel_output=True, topology=None):
        args = get_args()

        self.parallel_output = parallel_output
        self.hidden_size = args.hidden_size
        self.num_tokentypes = num_tokentypes
        self.init_method = init_method_normal(args.init_method_std)
        self.output_layer_init_method = scaled_init_method_normal(args.init_method_std, args.num_layers)
        weight_tying = not args.no_weight_tying
        if args.pos_emb == 'rpe':
            # A single relative position bias module, shared by every
            # transformer layer below.
            rpe_emb = ParallelRelativePositionBias(causal=True,
                                                   num_buckets=args.rpe_num_buckets,
                                                   max_distance=args.rpe_max_distance,
                                                   heads=args.num_attention_heads)
        else:
            rpe_emb = None
        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy

        #
        # forward() prototype
        #
        self.specs = []

        # Embedding layer
        if weight_tying:
            self.specs.append(TiedLayerSpec('embed',
                                            EmbeddingPipe,
                                            self.hidden_size,
                                            args.padded_vocab_size,
                                            args.max_position_embeddings,
                                            args.hidden_dropout,
                                            self.init_method,
                                            self.num_tokentypes,
                                            tied_weight_attr='word_embeddings_weight'))
        else:
            self.specs.append(LayerSpec(EmbeddingPipe,
                                        self.hidden_size,
                                        args.padded_vocab_size,
                                        args.max_position_embeddings,
                                        args.hidden_dropout,
                                        self.init_method,
                                        self.num_tokentypes))
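        # Explanatory note: in the weight-tying branch above, TiedLayerSpec
        # registers the embedding under the key 'embed'; the output head added
        # near the end of __init__ reuses that key, so DeepSpeed keeps the two
        # copies of 'word_embeddings_weight' tied across pipeline stages.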


        # outputs are now (hidden_states, attention_mask)

        # data format change to avoid explicit transposes: [b s h] --> [s b h]
        self.specs.append(lambda x: (x[0].transpose(0, 1).contiguous(), x[1]))

        # Transformer layers
        for x in range(args.num_layers):
            if args.sparsity == 'none':
                sparse = False
            elif args.sparsity == 'all':
                sparse = True
            elif args.sparsity == 'interspersed':
                # Alternate dense and sparse attention, starting with a dense layer.
                sparse = not x % 2 == 0
            self.specs.append(
                LayerSpec(ParallelTransformerLayerPipe,
                          attention_mask_func=gpt2_attention_mask_func,
                          init_method=self.init_method,
                          output_layer_init_method=self.output_layer_init_method,
                          layer_number=x,
                          sparse=sparse,
                          rpe=rpe_emb))
        # Undo data format change and drop mask
        self.specs.append(lambda x: x[0].transpose(0, 1).contiguous())

        # Final layernorm after transformer layers
        if args.norm == "rmsnorm":
            norm = RMSNorm
            eps = args.rms_norm_epsilon
        elif args.norm == "layernorm":
            eps = args.layernorm_epsilon
            norm = LayerNorm
        elif args.norm == "scalenorm":
            eps = args.scalenorm_epsilon
            norm = ScaleNorm

        self.specs.append(
            LayerSpec(norm,
                      args.hidden_size,
                      eps=eps))
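        # Note: args.norm is assumed to be restricted to the three choices
        # above by argument parsing; any other value would leave `norm` and
        # `eps` unbound at this point.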


        # XXX forward_method_parallel_output is assumed to be None, but we're not in a
        # fwd method to assert

        def _logits_helper(embedding, lm_output):
            """Just a wrapper to massage inputs/outputs from pipeline. """
            return parallel_lm_logits(
                lm_output,
                embedding.word_embeddings_weight,
                self.parallel_output)

        if weight_tying:
            # Reuse the 'embed' TiedLayerSpec so the output projection shares
            # its weight with the input word embeddings.
            self.specs.append(
                TiedLayerSpec('embed',
                              EmbeddingPipe,
                              self.hidden_size,
                              args.padded_vocab_size,
                              args.max_position_embeddings,
                              args.hidden_dropout,
                              self.init_method,
                              self.num_tokentypes,
                              forward_fn=_logits_helper,
                              tied_weight_attr='word_embeddings_weight')
            )
        else:
            # Untied output head: a separate model-parallel projection to the
            # (padded) vocabulary.
            self.specs.append(
                LayerSpec(
                    mpu.RowParallelLinear,
                    args.hidden_size,
                    args.padded_vocab_size,
                    bias=False,
                    input_is_parallel=False,
                    parallel_output=True,
                    skip_bias_add=False
                )
            )
            self.specs.append(lambda x: x[0])  # drop bias

        loss_fn = partial(cross_entropy, _fp16=self.fp16_lm_cross_entropy)
        if args.checkpoint_activations:
            interval = args.checkpoint_num_layers
        else:
            interval = 0
        super().__init__(layers=self.specs,
                         loss_fn=loss_fn,
                         topology=topology,
                         activation_checkpoint_interval=interval,
                         partition_method='type:transformer')
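# Usage sketch for the pipeline-parallel model (illustrative; assumes a
# DeepSpeed pipeline topology built from the configured pipe/model/data
# parallel sizes, as done in the training setup code):
#
#     from deepspeed.runtime.pipe.topology import PipeModelDataParallelTopology
#     topo = PipeModelDataParallelTopology(num_pp=..., num_mp=..., num_dp=...)
#     model = GPT2ModelPipe(num_tokentypes=0, parallel_output=True, topology=topo)
#
# The resulting module is then passed to deepspeed.initialize(...) for training.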