%%capture
!pip install unsloth

Train on Massive Datasets Without Downloading with Hugging Face Streaming and Unsloth
GPU poor and disk poor?
Unsloth has massively lowered the barriers to training and fine-tuning models by reducing the GPU resources required.
However, many AI datasets are very large — often multiple TBs. What if you want to train on a dataset larger than your disk?
Using 🤗 datasets + streaming means you can train directly from a Hugging Face hosted dataset without needing to download and store the whole dataset locally. This means even GPU and disk poor people can do things like continued pretraining!
Can we make Qwen speak Latin?
FineWeb-2 has 1.47 million Latin texts (~1.7GB). Let’s use them to try to teach a small LLM some Latin - no disk space required.
This is perfect for “GPU poor AND disk poor” setups - Colab, Kaggle, or any constrained environment.
from unsloth import FastLanguageModel
import torch
model, tokenizer = FastLanguageModel.from_pretrained(
"unsloth/Qwen3-0.6B-Base-unsloth-bnb-4bit",
max_seq_length=2048,
load_in_4bit=True,
)
model = FastLanguageModel.get_peft_model(
model,
r=16,
lora_alpha=16,
lora_dropout=0,
target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj"],
bias="none",
use_gradient_checkpointing="unsloth",
random_state=3407,
)
==((====))== Unsloth 2026.1.2: Fast Qwen3 patching. Transformers: 4.57.3.
\\ /| NVIDIA A100-SXM4-80GB. Num GPUs = 1. Max memory: 79.318 GB. Platform: Linux.
O^O/ \_/ \ Torch: 2.9.1+cu128. CUDA: 8.0. CUDA Toolkit: 12.8. Triton: 3.5.1
\ / Bfloat16 = TRUE. FA [Xformers = 0.0.33.post2. FA2 = False]
"-____-" Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Vibe Check: Before Training
Let’s see what this base model generates from a Latin prompt before any Latin training.
# Before any Latin training - what does the base model produce?
FastLanguageModel.for_inference(model)
inputs = tokenizer("Lingua Latina est", return_tensors="pt").to("cuda")
outputs = model.generate(
**inputs,
max_new_tokens=64,
temperature=0.7,
do_sample=True,
)
print("BEFORE:", tokenizer.decode(outputs[0], skip_special_tokens=True))
BEFORE: Lingua Latina estémao
# 04555 - Pórticos, conyugues y mujeres
Hasta ahora, el libro de esta edición ha tratado de la historia del púrpico en todo el mundo, pero sobre la mujer. No se trata, por lo tanto, de un
The Key: Streaming Dataset
The magic is streaming=True: data flows directly from the Hugging Face Hub, without needing to download it locally first.
from datasets import load_dataset
dataset = load_dataset(
"HuggingFaceFW/fineweb-2",
name="lat_Latn",
split="train",
streaming=True
)
# Peek at the data
next(iter(dataset))
{'text': 'Ita est in oratione senex mente confusus, eo quod illam imaginem Deitatis, quam proponere sibi in oratione consueverat, aboleri de suo corde sentiret, ut in amarissimos fletus, crebrosque singultus repente prorumpens, in terram prostratus, cum ejulatu validissimo proclamaret; "Heu me miserum! tulerunt a me Deum meum, et quem nunc teneam non habeo, vel quem adorem, aut interpallam am nescio." Cassian, Collat. x. 2.',
'id': '<urn:uuid:318d65fb-88ea-43cd-8687-8a8d802317a5>',
'dump': 'CC-MAIN-2013-20',
'url': 'http://www.ourcivilisation.com/smartboard/shop/gibbone/rome/volume2/nt470/013.htm',
'date': '2013-05-18T13:41:33Z',
'file_path': 's3://commoncrawl/crawl-data/CC-MAIN-2013-20/segments/1368696382398/warc/CC-MAIN-20130516092622-00034-ip-10-60-113-184.ec2.internal.warc.gz',
'language': 'lat',
'language_score': 0.9927687644958496,
'language_script': 'Latn',
'minhash_cluster_size': 10,
'top_langs': '{}'}
def format_text(example):
return {"text": example["text"] + tokenizer.eos_token}
formatted_dataset = dataset.map(format_text)
formatted_dataset
IterableDataset({
    features: Unknown,
    num_shards: 1
})
next(iter(formatted_dataset))['text']
'Ita est in oratione senex mente confusus, eo quod illam imaginem Deitatis, quam proponere sibi in oratione consueverat, aboleri de suo corde sentiret, ut in amarissimos fletus, crebrosque singultus repente prorumpens, in terram prostratus, cum ejulatu validissimo proclamaret; "Heu me miserum! tulerunt a me Deum meum, et quem nunc teneam non habeo, vel quem adorem, aut interpallam am nescio." Cassian, Collat. x. 2.<|endoftext|>'
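One thing to keep in mind (the trainer below just uses the unshuffled stream): a streamed dataset arrives in its original corpus order. If you want approximate shuffling, IterableDataset supports buffer-based shuffling; a minimal sketch, with a buffer size chosen arbitrarily:

# Optional: approximate shuffling for a streamed dataset (buffer size here is an arbitrary choice)
# The buffer is filled from the stream and examples are sampled from it, so memory stays bounded
# while the corpus order still gets mixed up.
shuffled_dataset = formatted_dataset.shuffle(seed=3407, buffer_size=1000)
next(iter(shuffled_dataset))["text"][:100]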
Training
For streaming datasets, set max_steps instead of a number of epochs: an IterableDataset has no known length, so the trainer can't work out how many steps make up an epoch.
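As a rough way to pick max_steps, divide an example budget by the effective batch size. A sketch using the batch settings below (the 800-example budget is simply the figure Unsloth reports in the training log further down):

# Back-of-envelope for max_steps with the batch settings used below
per_device_train_batch_size = 2
gradient_accumulation_steps = 4
effective_batch_size = per_device_train_batch_size * gradient_accumulation_steps  # 8 examples per step

example_budget = 800                                 # how many streamed examples to train on
max_steps = example_budget // effective_batch_size
print(max_steps)                                     # 100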
from trl import SFTTrainer, SFTConfig
trainer = SFTTrainer(
model=model,
tokenizer=tokenizer,
train_dataset=formatted_dataset,
args=SFTConfig(
per_device_train_batch_size=2,
gradient_accumulation_steps=4,
warmup_steps=5,
max_steps=100, # Use max_steps, not epochs!
learning_rate=2e-4,
logging_steps=10,
optim="adamw_8bit",
weight_decay=0.01,
lr_scheduler_type="linear",
seed=3407,
output_dir="outputs",
report_to="none",
dataset_text_field="text",
max_seq_length=2048,
packing=False,
),
)
trainer.train()
The model is already on multiple devices. Skipping the move to device specified in `args`.
==((====))== Unsloth - 2x faster free finetuning | Num GPUs used = 1
\\ /| Num examples = 800 | Num Epochs = 9,223,372,036,854,775,807 | Total steps = 100
O^O/ \_/ \ Batch size per device = 2 | Gradient accumulation steps = 4
\ / Data Parallel GPUs = 1 | Total batch size (2 x 4 x 1) = 8
"-____-" Trainable parameters = 10,092,544 of 606,142,464 (1.67% trained)
Unsloth: Will smartly offload gradients to save VRAM!
| Step | Training Loss |
|---|---|
| 10 | 3.229800 |
| 20 | 3.384800 |
| 30 | 3.717900 |
| 40 | 3.657300 |
| 50 | 3.491100 |
| 60 | 3.537200 |
| 70 | 3.623300 |
| 80 | 3.539600 |
| 90 | 3.585700 |
| 100 | 3.716300 |
TrainOutput(global_step=100, training_loss=3.5483124160766604, metrics={'train_runtime': 185.9221, 'train_samples_per_second': 4.303, 'train_steps_per_second': 0.538, 'total_flos': 2425465405440000.0, 'train_loss': 3.5483124160766604, 'epoch': 1.0})
Vibe Check: After Training
Same prompt, same settings - let’s see if 100 steps of Latin made a difference (probably not…)
# After Latin training - same prompt, same settings
FastLanguageModel.for_inference(model)
inputs = tokenizer("Lingua Latina est", return_tensors="pt").to("cuda")
outputs = model.generate(
**inputs,
max_new_tokens=64,
temperature=0.7,
do_sample=True,
)
print("AFTER:", tokenizer.decode(outputs[0], skip_special_tokens=True))
AFTER: Lingua Latina est invenire, ut invenire est. Ego, invenire, invenire, ut, ut, ut, ut, ut, ut, ut ut, ut, ut, ut, ut. Invenire, invenire, invenire, ut, ut, ut, ut,
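The run above doesn't show saving the result, but if you want to keep or share the adapter, the standard save/push calls work here. A minimal sketch (the repo and directory names are placeholders, not real repos):

# Save the LoRA adapter locally (adapter weights only, not the full model)
model.save_pretrained("qwen3-0.6b-latin-lora")
tokenizer.save_pretrained("qwen3-0.6b-latin-lora")

# Or push to the Hugging Face Hub (requires a logged-in token); repo name is a placeholder
# model.push_to_hub("your-username/qwen3-0.6b-latin-lora")
# tokenizer.push_to_hub("your-username/qwen3-0.6b-latin-lora")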
Scaling Up with HF Jobs
The notebook above works well for quick experiments, but streaming in Colab is limited by Colab's relatively slow network connection to the Hub.
What if we ran training directly on Hugging Face?
With HF Jobs, the compute is co-located with the data and has a much faster connection to the Hub, so streaming throughput roughly doubles!
| Environment | Speed | Bottleneck |
|---|---|---|
| Colab A100 | 0.36 it/s | Network latency |
| HF Jobs A100 | 0.74 it/s | GPU |
This streaming setup was fairly naive, so there are likely further tweaks that could make it even faster.
I trained a larger model (Qwen3 4B) for 1000 steps using a UV script on HF Jobs. The result: davanstrien/qwen3-4b-latin
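If you want to run something similar yourself, a UV script can be submitted to HF Jobs from the command line. A hedged sketch (the script URL is a placeholder, and the exact flags and hardware flavor names may differ, so check `hf jobs uv run --help`):

# Submit a UV training script to HF Jobs (script URL and --flavor value are placeholders)
!hf jobs uv run --flavor a100-large "https://huggingface.co/path/to/your_training_script.py"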
Try the Trained Model
# Load the Latin-trained 4B model
from unsloth import FastLanguageModel
model, tokenizer = FastLanguageModel.from_pretrained(
"davanstrien/qwen3-4b-latin",
max_seq_length=2048,
load_in_4bit=True,
)
FastLanguageModel.for_inference(model)
# Generate some Latin
inputs = tokenizer("Lingua Latina est", return_tensors="pt").to("cuda")
outputs = model.generate(
**inputs,
max_new_tokens=64,
temperature=0.7,
do_sample=True,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
==((====))== Unsloth 2026.1.2: Fast Qwen3 patching. Transformers: 4.57.3.
\\ /| NVIDIA A100-SXM4-80GB. Num GPUs = 1. Max memory: 79.318 GB. Platform: Linux.
O^O/ \_/ \ Torch: 2.9.1+cu128. CUDA: 8.0. CUDA Toolkit: 12.8. Triton: 3.5.1
\ / Bfloat16 = TRUE. FA [Xformers = 0.0.33.post2. FA2 = False]
"-____-" Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Lingua Latina est etum, nec.,usis. atit, et.,.. n.um sedam,em ei.,,que
,.am a. quisus et.,,, et a inus in e. aliqu.is,,....
,.,at ipsum
Still bad, but more Latin for sure!
The goal of this post wasn't really to show how to do continued pretraining effectively, but to show that you can train models without needing huge amounts of disk space.
Why This Matters
- No disk space needed - train on massive datasets without downloading
- Works everywhere - Colab, Kaggle, HF Jobs, any constrained environment
- Any language - FineWeb-2 has 90+ languages available
- Scales up - combine with HF Jobs for 2x faster streaming
The full UV script is available here if you want to train your own.
Performance Tips
If streaming feels slow, you can tune PyArrow’s prefetching when loading the dataset:
import pyarrow
import pyarrow.dataset
fragment_scan_options = pyarrow.dataset.ParquetFragmentScanOptions(
cache_options=pyarrow.CacheOptions(
prefetch_limit=1, # prefetch chunks in background
range_size_limit=128 << 20 # 128 MiB minimum request size (default 32 MiB)
),
)
dataset = load_dataset(..., streaming=True, fragment_scan_options=fragment_scan_options)
Learn More
- Streaming Datasets for ML Training - HF blog with more performance tuning tips
- Datasets Streaming Documentation - full docs