Module replacement support #586

Merged 39 commits on Jan 6, 2021

Commits
c9fb0c9
remove transformer layer ID from the top module
Jun 12, 2020
6254e46
updating docstring
Jun 12, 2020
3eab150
add inject
jeffra Jun 12, 2020
f6baecb
update inject PoC
jeffra Jun 17, 2020
b5acaca
fix the preln injection
Jun 18, 2020
b3f99b7
fix the preln injection
Jun 18, 2020
04a4d35
backward-test fixed
Jun 19, 2020
fca500f
backward-test fixed
Jun 19, 2020
f3ff21e
update with replace module style
jeffra Jun 25, 2020
344b016
Merge branch 'master' into jeffra/inject
jeffra Jun 25, 2020
c208bdf
add function to revert from ds kernel -> orig layer
jeffra Jul 16, 2020
c278562
add code from Elton to do ds kernel -> orig layer conversion
jeffra Jul 17, 2020
68d8c13
formatting
jeffra Jul 17, 2020
3161565
update replace to fix runtime errors
jeffra Jul 22, 2020
798e6d3
remove pillow
jeffra Jul 29, 2020
66f590d
remove transformer layer ID from the top module
Jun 12, 2020
e4b46fb
updating docstring
Jun 12, 2020
25ee5e7
add inject
jeffra Jun 12, 2020
d5d10e9
update inject PoC
jeffra Jun 17, 2020
e090049
fix the preln injection
Jun 18, 2020
3df72f8
fix the preln injection
Jun 18, 2020
41cc4e6
backward-test fixed
Jun 19, 2020
24a3d24
backward-test fixed
Jun 19, 2020
66b4e63
update with replace module style
jeffra Jun 25, 2020
ee40034
add function to revert from ds kernel -> orig layer
jeffra Jul 16, 2020
e982c65
add code from Elton to do ds kernel -> orig layer conversion
jeffra Jul 17, 2020
fd4d0bc
formatting
jeffra Jul 17, 2020
e332d61
update replace to fix runtime errors
jeffra Jul 22, 2020
5814254
rebase-complete
Dec 4, 2020
55ffb88
merging
Dec 4, 2020
f48c52a
resolve conflict
Dec 4, 2020
0df72a0
remove dup line and add local-rank parameter to replace function
Dec 4, 2020
b28bf16
Merge pull request #585 from microsoft/jeffra/inject
jeffra Dec 7, 2020
f7dbf11
Merge branch 'master' into jeffra/inject_v2
jeffra Dec 7, 2020
fc39aa2
Merge branch 'master' into jeffra/inject_v2
jeffra Dec 9, 2020
ee13f81
remove remaining rebase conflict text
jeffra Dec 9, 2020
e5d3b50
Merge branch 'jeffra/inject_v2' of github.com:microsoft/DeepSpeed int…
jeffra Dec 9, 2020
1b42798
Kernel injection fixes (#601)
tjruwase Dec 15, 2020
a5a34c6
Merge branch 'master' into jeffra/inject_v2
jeffra Jan 6, 2021
122 changes: 122 additions & 0 deletions deepspeed/module_inject/inject.py
@@ -0,0 +1,122 @@
import copy
import torch
from deepspeed.ops.transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig


def module_inject(layer_obj,
                  model,
                  config,
                  micro_batch_size,
                  max_seq_length,
                  seed,
                  preln,
                  fp16=True):
    """Recursively replace every instance of ``layer_obj`` inside ``model``
    with a ``DeepSpeedTransformerLayer`` built from ``config``, copying the
    weights of each replaced layer."""
    for name, child in model.named_children():
        if isinstance(child, layer_obj):
            print('REPLACING BertLayer')

            cuda_config = DeepSpeedTransformerConfig(
                batch_size=micro_batch_size,
                max_seq_length=max_seq_length,
                hidden_size=config.hidden_size,
                heads=config.num_attention_heads,
                attn_dropout_ratio=config.attention_probs_dropout_prob,
                hidden_dropout_ratio=config.hidden_dropout_prob,
                num_hidden_layers=config.num_hidden_layers,
                initializer_range=config.initializer_range,
                seed=seed,
                fp16=fp16,
                pre_layer_norm=preln)

            new_module = DeepSpeedTransformerLayer(cuda_config)

            # copy relevant state from child -> new module
            qw = child.attention.self.query.weight
            qb = child.attention.self.query.bias
            kw = child.attention.self.key.weight
            kb = child.attention.self.key.bias
            vw = child.attention.self.value.weight
            vb = child.attention.self.value.bias

            # fuse q/k/v into a single [3*hidden, hidden] weight and [3*hidden] bias
            qkvw = torch.cat((qw, kw, vw), 0)
            qkvb = torch.cat((qb, kb, vb), 0)

            new_module.attn_qkvw.data = qkvw
            new_module.attn_qkvb.data = qkvb
            new_module.attn_ow.data = child.attention.output.dense.weight
            new_module.attn_ob.data = child.attention.output.dense.bias
            if preln:
                attention_layerNorm = child.PostAttentionLayerNorm
            else:
                attention_layerNorm = child.attention.output.LayerNorm
            new_module.attn_nw.data = attention_layerNorm.weight
            new_module.attn_nb.data = attention_layerNorm.bias
            if preln:
                intermediate_FF = child.intermediate.dense_act
            else:
                intermediate_FF = child.intermediate.dense
            new_module.inter_w.data = intermediate_FF.weight
            new_module.inter_b.data = intermediate_FF.bias
            new_module.output_w.data = child.output.dense.weight
            new_module.output_b.data = child.output.dense.bias
            if preln:
                transformer_LayerNorm = child.PreAttentionLayerNorm
            else:
                transformer_LayerNorm = child.output.LayerNorm
            new_module.norm_w.data = transformer_LayerNorm.weight
            new_module.norm_b.data = transformer_LayerNorm.bias

            setattr(model, name, copy.deepcopy(new_module))

        else:
            module_inject(layer_obj,
                          child,
                          config,
                          micro_batch_size,
                          max_seq_length,
                          seed,
                          preln,
                          fp16)

    return model


def test_hi():
    from turing.nvidia_modelingpreln import BertConfig as BertConfigPreLN
    from turing.nvidia_modelingpreln import BertForQuestionAnswering as BertForQuestionAnsweringPreLN
    from turing.nvidia_modelingpreln import BertLayer
    bert_model_config = {
        "vocab_size_or_config_json_file": 119547,
        "hidden_size": 1024,
        "num_hidden_layers": 1,
        "num_attention_heads": 16,
        "intermediate_size": 4096,
        "hidden_act": "gelu",
        "hidden_dropout_prob": 0.1,
        "attention_probs_dropout_prob": 0.1,
        "max_position_embeddings": 512,
        "type_vocab_size": 2,
        "initializer_range": 0.02
    }
    bert_config = BertConfigPreLN(**bert_model_config)
    base_model = BertForQuestionAnsweringPreLN(bert_config, args=None)

    #base_model = LinearStack()

    test_model = copy.deepcopy(base_model)
    test_model = module_inject(BertLayer,
                               test_model,
                               bert_config,
                               micro_batch_size=4,
                               max_seq_length=384,
                               seed=1234,
                               preln=True)  # the PreLN BertLayer variant is being replaced

    print('BASE', base_model)
    print('TEST', test_model)

    #base_model.eval()
    #test_model.eval()

    #test_input = torch.rand(1, base_model.input_dim)

    #base_output = base_model(test_input)
    #test_output = test_model(test_input)
    #
    #assert torch.allclose(base_output, test_output, atol=3e-8)
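The core of inject.py is the recursive walk over named_children(): replace a matching child in place via setattr, otherwise recurse into the child. A minimal, self-contained sketch of the same pattern, independent of the DeepSpeed kernels and the turing package (the replace_children helper and the Dropout-for-Identity swap below are purely illustrative, not part of this PR):

import torch.nn as nn

def replace_children(model, target_cls, make_new):
    # Walk direct children: swap any instance of target_cls in place,
    # otherwise recurse into the child's own submodules.
    for name, child in model.named_children():
        if isinstance(child, target_cls):
            setattr(model, name, make_new(child))
        else:
            replace_children(child, target_cls, make_new)
    return model

# Illustrative use: swap every nn.Dropout for nn.Identity.
net = nn.Sequential(nn.Linear(8, 8), nn.Dropout(0.1), nn.Linear(8, 2))
net = replace_children(net, nn.Dropout, lambda old: nn.Identity())
print(net)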
192 changes: 192 additions & 0 deletions deepspeed/module_inject/replace_module.py
@@ -0,0 +1,192 @@
import copy
import torch
import deepspeed


def replace_transformer_layer(orig_layer_impl,
                              model,
                              micro_batch_size,
                              bert_config,
                              seed,
                              max_seq_length,
                              preln=False,
                              fp16=True,
                              huggingface=False,
                              local_rank=-1):
    """ Replace bert-style transformer layers with DeepSpeed's transformer layer
    Arguments:
        orig_layer_impl (torch.nn.Module): the original transformer layer implementation to look for,
            e.g., transformers.modeling_bert.BertLayer.
        model (torch.nn.Module): user's nn.module representing their model
        micro_batch_size (int): micro batch size per GPU used during training/eval
        bert_config (BertConfig): model config containing hidden size, attention heads, etc.
        seed (int): random seed value
        max_seq_length (int): max sequence length for training
        preln (bool): whether the original layer uses pre-layer-norm (True) or post-layer-norm (False)
        fp16 (bool): fp16 or fp32
        huggingface (bool): whether the layer comes from HuggingFace Transformers, whose
            implementation is unique (it supports both encoder and decoder modes)
        local_rank (int): the local GPU rank (default: -1)

    Returns:
        Updated nn.module with replaced transformer layers
    """
    def replace_fn(child):
        transformer_config = deepspeed.DeepSpeedTransformerConfig(
            batch_size=micro_batch_size,
            max_seq_length=max_seq_length,
            hidden_size=bert_config.hidden_size,
            heads=bert_config.num_attention_heads,
            attn_dropout_ratio=bert_config.attention_probs_dropout_prob,
            hidden_dropout_ratio=bert_config.hidden_dropout_prob,
            num_hidden_layers=bert_config.num_hidden_layers,
            initializer_range=bert_config.initializer_range,
            seed=seed,
            fp16=fp16,
            pre_layer_norm=preln,
            huggingface=huggingface,
            local_rank=local_rank)
        new_module = deepspeed.DeepSpeedTransformerLayer(transformer_config)

        # copy relevant state from child -> new module
        qw = child.attention.self.query.weight
        qb = child.attention.self.query.bias
        kw = child.attention.self.key.weight
        kb = child.attention.self.key.bias
        vw = child.attention.self.value.weight
        vb = child.attention.self.value.bias

        # fuse q/k/v into a single [3*hidden, hidden] weight and [3*hidden] bias
        qkvw = torch.cat((qw, kw, vw), 0)
        qkvb = torch.cat((qb, kb, vb), 0)

        #qw.data,kw.data,vw.data = torch.chunk(qkvw, 3, dim=0)
        #qb.data,kb.data,vb.data = torch.chunk(qkvb, 3, dim=0)

        new_module.attn_qkvw.data = qkvw
        new_module.attn_qkvb.data = qkvb
        new_module.attn_ow.data = child.attention.output.dense.weight
        new_module.attn_ob.data = child.attention.output.dense.bias
        if preln:
            attention_layernorm = child.PostAttentionLayerNorm
        else:
            attention_layernorm = child.attention.output.LayerNorm
        new_module.attn_nw.data = attention_layernorm.weight
        new_module.attn_nb.data = attention_layernorm.bias
        if preln:
            intermediate_ff = child.intermediate.dense_act
        else:
            intermediate_ff = child.intermediate.dense
        new_module.inter_w.data = intermediate_ff.weight
        new_module.inter_b.data = intermediate_ff.bias
        new_module.output_w.data = child.output.dense.weight
        new_module.output_b.data = child.output.dense.bias
        if preln:
            transformer_layernorm = child.PreAttentionLayerNorm
        else:
            transformer_layernorm = child.output.LayerNorm
        new_module.norm_w.data = transformer_layernorm.weight
        new_module.norm_b.data = transformer_layernorm.bias
        return new_module

    return replace_module(model=model, orig_class=orig_layer_impl, replace_fn=replace_fn)


def revert_transformer_layer(orig_layer_impl, model, bert_config, preln=False):
    """ Revert DeepSpeed's transformer layer back to original bert-style transformer layer
    Arguments:
        orig_layer_impl (torch.nn.Module): the original transformer layer implementation that was replaced,
            e.g., transformers.modeling_bert.BertLayer.
        model (torch.nn.Module): user's nn.module representing their model
        bert_config (BertConfig): model config containing hidden size, attention heads, etc.
        preln (bool): whether the original layer uses pre-layer-norm (True) or post-layer-norm (False)

    Returns:
        Updated nn.module with original bert-style transformer layers
    """
    def replace_fn(child):
        #from turing.nvidia_modelingpreln import BertLayer
        orig_module = orig_layer_impl(bert_config)

        # copy relevant state from child -> original module
        qkvw = child.attn_qkvw.data
        qkvb = child.attn_qkvb.data

        # split the fused qkv parameters back into separate q/k/v tensors
        qw, kw, vw = torch.chunk(qkvw, 3, dim=0)
        qb, kb, vb = torch.chunk(qkvb, 3, dim=0)

        orig_module.attention.self.query.weight.data = qw
        orig_module.attention.self.query.bias.data = qb
        orig_module.attention.self.key.weight.data = kw
        orig_module.attention.self.key.bias.data = kb
        orig_module.attention.self.value.weight.data = vw
        orig_module.attention.self.value.bias.data = vb

        orig_module.attention.output.dense.weight.data = child.attn_ow.data
        orig_module.attention.output.dense.bias.data = child.attn_ob.data

        attn_ln_w = child.attn_nw.data
        attn_ln_b = child.attn_nb.data
        if preln:
            orig_module.PostAttentionLayerNorm.weight.data = attn_ln_w
            orig_module.PostAttentionLayerNorm.bias.data = attn_ln_b
        else:
            orig_module.attention.output.LayerNorm.weight.data = attn_ln_w
            orig_module.attention.output.LayerNorm.bias.data = attn_ln_b

        inter_ff_w = child.inter_w.data
        inter_ff_b = child.inter_b.data
        if preln:
            orig_module.intermediate.dense_act.weight.data = inter_ff_w
            orig_module.intermediate.dense_act.bias.data = inter_ff_b
        else:
            orig_module.intermediate.dense.weight.data = inter_ff_w
            orig_module.intermediate.dense.bias.data = inter_ff_b

        orig_module.output.dense.weight.data = child.output_w.data
        orig_module.output.dense.bias.data = child.output_b.data

        transformer_ln_w = child.norm_w.data
        transformer_ln_b = child.norm_b.data
        if preln:
            orig_module.PreAttentionLayerNorm.weight.data = transformer_ln_w
            orig_module.PreAttentionLayerNorm.bias.data = transformer_ln_b
        else:
            orig_module.output.LayerNorm.weight.data = transformer_ln_w
            orig_module.output.LayerNorm.bias.data = transformer_ln_b
        return orig_module

    return replace_module(model=model,
                          orig_class=deepspeed.DeepSpeedTransformerLayer,
                          replace_fn=replace_fn)


def replace_module(model, orig_class, replace_fn):
    """ Scan the model for instances of ``orig_class`` to replace using ``replace_fn``.
    Arguments:
        model (torch.nn.Module): the model to augment
        orig_class (torch.nn.Module): the module class to search for
        replace_fn (method): a method to convert instances of ``orig_class`` to the
            desired type and return a new instance.

    Returns:
        A modified ``model``.
    """
    policy = {orig_class: replace_fn}
    return _replace_module(model, policy)


def _replace_module(model, policies):
    """ Traverse model's children recursively and apply any transformations in ``policies``.
    Arguments:
        model (torch.nn.Module): model to augment
        policies (dict): Mapping of source class to replacement function.

    Returns:
        Modified ``model``.
    """
    for name, child in model.named_children():
        if child.__class__ in policies:
            setattr(model, name, policies[child.__class__](child))
        else:
            _replace_module(child, policies)

    return model
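A hedged usage sketch of the new replace/revert round trip. The import path below is an assumption based on the file location shown in this diff (the new deepspeed/module_inject/__init__.py may or may not re-export these names), and transformers.modeling_bert.BertLayer is the HuggingFace path named in the docstring, which newer transformers releases have since moved; it also assumes a CUDA build of the DeepSpeed transformer kernels.

# Sketch only, not part of this PR.
import transformers
from deepspeed.module_inject.replace_module import (replace_transformer_layer,
                                                    revert_transformer_layer)

bert_config = transformers.BertConfig()
model = transformers.BertModel(bert_config)

# Swap every HuggingFace BertLayer for the fused DeepSpeed kernel layer...
model = replace_transformer_layer(transformers.modeling_bert.BertLayer,
                                  model,
                                  micro_batch_size=8,
                                  bert_config=bert_config,
                                  seed=1234,
                                  max_seq_length=512,
                                  preln=False,
                                  huggingface=True)

# ...and revert to the original implementation, e.g. before checkpoint export.
model = revert_transformer_layer(transformers.modeling_bert.BertLayer,
                                 model,
                                 bert_config,
                                 preln=False)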
3 changes: 3 additions & 0 deletions deepspeed/ops/__init__.py
100644 → 100755
@@ -3,4 +3,7 @@
from . import sparse_attention
from . import transformer

from .transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig
from .module_inject import replace_module

from ..git_version_info import compatible_ops as __compatible_ops__