
Commit 5c9420d

suiyoubin authored and broad1881 committed
Add Nemotron HF Support (#31699)
* Add nemotron support
* fix inference
* add unit test
* add layernorm1p as a class to avoid meta device mismatch
* test fixed
* Add copied_from statements
* remove pretraining_tp args
* remove nemotronlayernorm
* force LN computation done in FP32
* remove nemotrontokenizer and use llamatokenizer
* license update
* add option for kv_channels for minitron8b
* remove assert
* o_proj fixed
* o_proj reshape
* add gated_proj option
* typo
* remove todos
* fix broken test after merging latest main
* remove nezha/nat after merging main
* change default config to 15b model
* add nemo conversion script
* rename conversion script
* remove gate_proj option
* pr comment resolved
* fix unit test
* rename kv_channels to head_dim
* resolve PR issue
* add nemotron md
* fix broken tests
* refactor rope for nemotron
* test fix
* remove linearscaling
* whitespace and import
* fix some copied-from
* code style fix
* reformatted
* add position_embedding to nemotronattention
* rope refactor to only use config, copied-from fix
* format
* Run make fix-copies
* nemotron md with autodoc
* doc fix
* fix order
* pass check_config_docstrings.py
* fix config_attributes
* remove all llama BC related code
* Use PreTrainedTokenizerFast
* ruff check examples
* conversion script update
* add nemotron to toctree
1 parent 2a0e440 commit 5c9420d
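Two notes in the change list above, "add layernorm1p as a class" and "force LN computation done in FP32", refer to Nemotron's zero-centered LayerNorm. Below is a minimal illustrative sketch of that idea only; the class name and dtype handling here are assumptions, not the merged implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LayerNorm1P(nn.LayerNorm):
    """Illustrative 'zero-centered gamma' LayerNorm: the scale is stored as (gamma - 1)
    and the 1 is added back at call time, with the normalization math done in fp32."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        orig_dtype = x.dtype
        out = F.layer_norm(
            x.float(),                       # force the LayerNorm computation in fp32
            self.normalized_shape,
            (self.weight + 1).float(),       # the "+1" that gives layernorm1p its name
            self.bias.float() if self.bias is not None else None,
            self.eps,
        )
        return out.to(orig_dtype)
```

Storing the scale as gamma - 1 keeps freshly initialized weights at zero, which is the property the layernorm1p entries in the change list refer to.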

File tree

15 files changed: +2449 -0 lines changed

docs/source/en/_toctree.yml

Lines changed: 2 additions & 0 deletions
@@ -468,6 +468,8 @@
       title: MT5
     - local: model_doc/mvp
       title: MVP
+    - local: model_doc/nemotron
+      title: Nemotron
     - local: model_doc/nezha
       title: NEZHA
     - local: model_doc/nllb

docs/source/en/index.md

Lines changed: 1 addition & 0 deletions
@@ -222,6 +222,7 @@ Flax), PyTorch, and/or TensorFlow.
 | [MusicGen Melody](model_doc/musicgen_melody) ||||
 | [MVP](model_doc/mvp) ||||
 | [NAT](model_doc/nat) ||||
+| [Nemotron](model_doc/nemotron) ||||
 | [Nezha](model_doc/nezha) ||||
 | [NLLB](model_doc/nllb) ||||
 | [NLLB-MOE](model_doc/nllb-moe) ||||

docs/source/en/model_doc/nemotron.md

Lines changed: 148 additions & 0 deletions
@@ -0,0 +1,148 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

-->

# Nemotron

## Nemotron

### License

The use of this model is governed by the [NVIDIA AI Foundation Models Community License Agreement](https://developer.nvidia.com/downloads/nv-ai-foundation-models-license).

### Description

Nemotron-4 is a family of enterprise-ready generative text models compatible with the [NVIDIA NeMo Framework](https://www.nvidia.com/en-us/ai-data-science/generative-ai/nemo-framework/).

NVIDIA NeMo is an end-to-end, cloud-native platform to build, customize, and deploy generative AI models anywhere. It includes training and inferencing frameworks, guardrailing toolkits, data curation tools, and pretrained models, offering enterprises an easy, cost-effective, and fast way to adopt generative AI. To get access to NeMo Framework, please sign up at [this link](https://developer.nvidia.com/nemo-framework/join).

### References

[Announcement Blog](https://developer.nvidia.com/blog/nvidia-ai-foundation-models-build-custom-enterprise-chatbots-and-co-pilots-with-production-ready-llms/)

### Model Architecture

**Architecture Type:** Transformer

**Network Architecture:** Transformer Decoder (auto-regressive language model).

## Minitron

### Minitron 4B Base

Minitron is a family of small language models (SLMs) obtained by pruning NVIDIA's [Nemotron-4 15B](https://arxiv.org/abs/2402.16819) model. We prune the model's embedding size, attention heads, and MLP intermediate dimension, after which we perform continued training with distillation to arrive at the final models.

Deriving the Minitron 8B and 4B models from the base 15B model using our approach requires up to **40x fewer training tokens** per model compared to training from scratch; this results in **compute cost savings of 1.8x** for training the full model family (15B, 8B, and 4B). Minitron models exhibit up to a 16% improvement in MMLU scores compared to training from scratch, perform comparably to other community models such as Mistral 7B, Gemma 7B and Llama-3 8B, and outperform state-of-the-art compression techniques from the literature. Please refer to our [arXiv paper](https://arxiv.org/abs/2407.14679) for more details.

Minitron models are for research and development only.

### HuggingFace Quickstart

The following code provides an example of how to load the Minitron-4B model and use it to perform text generation.

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load the tokenizer and model
model_path = 'nvidia/Minitron-4B-Base'
tokenizer = AutoTokenizer.from_pretrained(model_path)

device = 'cuda'
dtype = torch.bfloat16
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=dtype, device_map=device)

# Prepare the input text
prompt = 'Complete the paragraph: our solar system is'
inputs = tokenizer.encode(prompt, return_tensors='pt').to(model.device)

# Generate the output
outputs = model.generate(inputs, max_length=20)

# Decode and print the output
output_text = tokenizer.decode(outputs[0])
print(output_text)
```

### License

Minitron is released under the [NVIDIA Open Model License Agreement](https://developer.download.nvidia.com/licenses/nvidia-open-model-license-agreement-june-2024.pdf).

### Evaluation Results

*5-shot performance.* Language Understanding evaluated using [Massive Multitask Language Understanding](https://arxiv.org/abs/2009.03300):

| Average |
| :---- |
| 58.6 |

*Zero-shot performance.* Evaluated using select datasets from the [LM Evaluation Harness](https://github.com/EleutherAI/lm-evaluation-harness) with additions:

| HellaSwag | Winogrande | GSM8K | ARC-C | XLSum |
| :------------- | :------------- | :------------- | :------------- | :------------- |
| 75.0 | 74.0 | 24.1 | 50.9 | 29.5 |

*Code generation performance.* Evaluated using [HumanEval](https://github.com/openai/human-eval):

| p@1, 0-Shot |
| :------------- |
| 23.3 |

Please refer to our [paper](https://arxiv.org/abs/2407.14679) for the full set of results.

### Citation

If you find our work helpful, please consider citing our paper:
```
@article{minitron2024,
      title={Compact Language Models via Pruning and Knowledge Distillation},
      author={Saurav Muralidharan and Sharath Turuvekere Sreenivas and Raviraj Joshi and Marcin Chochowski and Mostofa Patwary and Mohammad Shoeybi and Bryan Catanzaro and Jan Kautz and Pavlo Molchanov},
      journal={arXiv preprint arXiv:2407.14679},
      year={2024},
      url={https://arxiv.org/abs/2407.14679},
}
```

## NemotronConfig

[[autodoc]] NemotronConfig

## NemotronModel

[[autodoc]] NemotronModel
    - forward

## NemotronForCausalLM

[[autodoc]] NemotronForCausalLM
    - forward

## NemotronForSequenceClassification

[[autodoc]] NemotronForSequenceClassification
    - forward

## NemotronForQuestionAnswering

[[autodoc]] NemotronForQuestionAnswering
    - forward

## NemotronForTokenClassification

[[autodoc]] NemotronForTokenClassification
    - forward
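The autodoc entries above cover the configuration and the model classes. As a rough illustration of how they compose, here is a hedged sketch that builds a toy-sized model from a hand-written config; every hyperparameter value below is a placeholder chosen to keep the model small, not a Nemotron default.

```python
from transformers import NemotronConfig, NemotronForCausalLM

# Deliberately tiny, illustrative values; the default NemotronConfig describes the 15B model.
config = NemotronConfig(
    vocab_size=1024,
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=4,
)

model = NemotronForCausalLM(config)  # randomly initialized, for shape/debug checks only
print(model.config.model_type, sum(p.numel() for p in model.parameters()))
```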

docs/source/en/perf_infer_gpu_one.md

Lines changed: 2 additions & 0 deletions
@@ -67,6 +67,7 @@ FlashAttention-2 is currently supported for the following architectures:
 * [Mixtral](https://huggingface.co/docs/transformers/model_doc/mixtral#transformers.MixtralModel)
 * [Musicgen](https://huggingface.co/docs/transformers/model_doc/musicgen#transformers.MusicgenModel)
 * [MusicGen Melody](https://huggingface.co/docs/transformers/model_doc/musicgen_melody#transformers.MusicgenMelodyModel)
+* [Nemotron](https://huggingface.co/docs/transformers/model_doc/nemotron)
 * [NLLB](https://huggingface.co/docs/transformers/model_doc/nllb)
 * [OLMo](https://huggingface.co/docs/transformers/model_doc/olmo#transformers.OlmoModel)
 * [OPT](https://huggingface.co/docs/transformers/model_doc/opt#transformers.OPTModel)

@@ -228,6 +229,7 @@ For now, Transformers supports SDPA inference and training for the following architectures:
 * [Qwen2MoE](https://huggingface.co/docs/transformers/model_doc/qwen2_moe#transformers.Qwen2MoeModel)
 * [Musicgen](https://huggingface.co/docs/transformers/model_doc/musicgen#transformers.MusicgenModel)
 * [MusicGen Melody](https://huggingface.co/docs/transformers/model_doc/musicgen_melody#transformers.MusicgenMelodyModel)
+* [Nemotron](https://huggingface.co/docs/transformers/model_doc/nemotron)
 * [ViT](https://huggingface.co/docs/transformers/model_doc/vit#transformers.ViTModel)
 * [ViTHybrid](https://huggingface.co/docs/transformers/model_doc/vit_hybrid#transformers.ViTHybridModel)
 * [ViTMAE](https://huggingface.co/docs/transformers/model_doc/vit_mae#transformers.ViTMAEModel)
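Since Nemotron now appears in both lists, either backend can be selected with the usual `attn_implementation` argument at load time. A minimal sketch, assuming a Nemotron-architecture checkpoint such as `nvidia/Minitron-4B-Base`, a CUDA GPU, and the `flash-attn` and `accelerate` packages:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "nvidia/Minitron-4B-Base"  # assumed Nemotron-architecture checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)

# FlashAttention-2 path (needs flash-attn and fp16/bf16 weights);
# pass attn_implementation="sdpa" to use PyTorch's scaled_dot_product_attention instead.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",  # or "sdpa"
    device_map="cuda",
)

inputs = tokenizer("Complete the paragraph: our solar system is", return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=20)[0]))
```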

src/transformers/__init__.py

Lines changed: 20 additions & 0 deletions
@@ -592,6 +592,7 @@
         "MusicgenMelodyDecoderConfig",
     ],
     "models.mvp": ["MvpConfig", "MvpTokenizer"],
+    "models.nemotron": ["NemotronConfig"],
     "models.nllb": [],
     "models.nllb_moe": ["NllbMoeConfig"],
     "models.nougat": ["NougatProcessor"],

@@ -2742,6 +2743,16 @@
             "MvpPreTrainedModel",
         ]
     )
+    _import_structure["models.nemotron"].extend(
+        [
+            "NemotronForCausalLM",
+            "NemotronForQuestionAnswering",
+            "NemotronForSequenceClassification",
+            "NemotronForTokenClassification",
+            "NemotronModel",
+            "NemotronPreTrainedModel",
+        ]
+    )
     _import_structure["models.nllb_moe"].extend(
         [
             "NllbMoeForConditionalGeneration",

@@ -5286,6 +5297,7 @@
         MusicgenMelodyDecoderConfig,
     )
     from .models.mvp import MvpConfig, MvpTokenizer
+    from .models.nemotron import NemotronConfig
     from .models.nllb_moe import NllbMoeConfig
     from .models.nougat import NougatProcessor
     from .models.nystromformer import (

@@ -7187,6 +7199,14 @@
         MvpModel,
         MvpPreTrainedModel,
     )
+    from .models.nemotron import (
+        NemotronForCausalLM,
+        NemotronForQuestionAnswering,
+        NemotronForSequenceClassification,
+        NemotronForTokenClassification,
+        NemotronModel,
+        NemotronPreTrainedModel,
+    )
     from .models.nllb_moe import (
         NllbMoeForConditionalGeneration,
         NllbMoeModel,
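With these entries in the top-level import structure, the classes are importable directly from `transformers` and resolve lazily. A small sketch; building a model from the default config is avoided here because, per the change list, the defaults describe the 15B variant.

```python
from transformers import NemotronConfig, NemotronForCausalLM  # resolved via the lazy top-level module

config = NemotronConfig()   # default hyperparameters target the 15B model
print(config.model_type)    # -> "nemotron"
print(NemotronForCausalLM.__module__)  # -> "transformers.models.nemotron.modeling_nemotron"
```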

src/transformers/models/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -159,6 +159,7 @@
     musicgen,
     musicgen_melody,
     mvp,
+    nemotron,
     nllb,
     nllb_moe,
     nougat,

src/transformers/models/auto/configuration_auto.py

Lines changed: 2 additions & 0 deletions
@@ -177,6 +177,7 @@
         ("musicgen_melody", "MusicgenMelodyConfig"),
         ("mvp", "MvpConfig"),
         ("nat", "NatConfig"),
+        ("nemotron", "NemotronConfig"),
         ("nezha", "NezhaConfig"),
         ("nllb-moe", "NllbMoeConfig"),
         ("nougat", "VisionEncoderDecoderConfig"),

@@ -469,6 +470,7 @@
         ("musicgen_melody", "MusicGen Melody"),
         ("mvp", "MVP"),
         ("nat", "NAT"),
+        ("nemotron", "Nemotron"),
         ("nezha", "Nezha"),
         ("nllb", "NLLB"),
         ("nllb-moe", "NLLB-MOE"),

src/transformers/models/auto/modeling_auto.py

Lines changed: 5 additions & 0 deletions
@@ -169,6 +169,7 @@
         ("musicgen_melody", "MusicgenMelodyModel"),
         ("mvp", "MvpModel"),
         ("nat", "NatModel"),
+        ("nemotron", "NemotronModel"),
         ("nezha", "NezhaModel"),
         ("nllb-moe", "NllbMoeModel"),
         ("nystromformer", "NystromformerModel"),

@@ -481,6 +482,7 @@
         ("musicgen", "MusicgenForCausalLM"),
         ("musicgen_melody", "MusicgenMelodyForCausalLM"),
         ("mvp", "MvpForCausalLM"),
+        ("nemotron", "NemotronForCausalLM"),
         ("olmo", "OlmoForCausalLM"),
         ("open-llama", "OpenLlamaForCausalLM"),
         ("openai-gpt", "OpenAIGPTLMHeadModel"),

@@ -902,6 +904,7 @@
         ("mra", "MraForSequenceClassification"),
         ("mt5", "MT5ForSequenceClassification"),
         ("mvp", "MvpForSequenceClassification"),
+        ("nemotron", "NemotronForSequenceClassification"),
         ("nezha", "NezhaForSequenceClassification"),
         ("nystromformer", "NystromformerForSequenceClassification"),
         ("open-llama", "OpenLlamaForSequenceClassification"),

@@ -983,6 +986,7 @@
         ("mra", "MraForQuestionAnswering"),
         ("mt5", "MT5ForQuestionAnswering"),
         ("mvp", "MvpForQuestionAnswering"),
+        ("nemotron", "NemotronForQuestionAnswering"),
         ("nezha", "NezhaForQuestionAnswering"),
         ("nystromformer", "NystromformerForQuestionAnswering"),
         ("opt", "OPTForQuestionAnswering"),

@@ -1078,6 +1082,7 @@
         ("mpt", "MptForTokenClassification"),
         ("mra", "MraForTokenClassification"),
         ("mt5", "MT5ForTokenClassification"),
+        ("nemotron", "NemotronForTokenClassification"),
         ("nezha", "NezhaForTokenClassification"),
         ("nystromformer", "NystromformerForTokenClassification"),
         ("persimmon", "PersimmonForTokenClassification"),
src/transformers/models/nemotron/__init__.py

Lines changed: 68 additions & 0 deletions
@@ -0,0 +1,68 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import (
    OptionalDependencyNotAvailable,
    _LazyModule,
    is_sentencepiece_available,
    is_torch_available,
)


_import_structure = {
    "configuration_nemotron": ["NemotronConfig"],
}


try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_nemotron"] = [
        "NemotronForQuestionAnswering",
        "NemotronForCausalLM",
        "NemotronModel",
        "NemotronPreTrainedModel",
        "NemotronForSequenceClassification",
        "NemotronForTokenClassification",
    ]


if TYPE_CHECKING:
    from .configuration_nemotron import NemotronConfig

    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_nemotron import (
            NemotronForCausalLM,
            NemotronForQuestionAnswering,
            NemotronForSequenceClassification,
            NemotronForTokenClassification,
            NemotronModel,
            NemotronPreTrainedModel,
        )

else:
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
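This mirrors the library's standard lazy-import layout: the configuration entry is always exposed, while the modeling classes are registered only when `torch` can be imported. An illustrative check of that wiring, not part of the diff:

```python
from transformers.models import nemotron

# dir() reflects the _import_structure built above: the configuration entry is
# unconditional, the modeling entries exist only when torch is available.
print("NemotronConfig" in dir(nemotron))       # True
print("NemotronForCausalLM" in dir(nemotron))  # True only with torch installed

# Attribute access is what actually triggers the import of the submodule.
cfg_cls = nemotron.NemotronConfig
print(cfg_cls.__module__)  # -> "transformers.models.nemotron.configuration_nemotron"
```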
