This commit is contained in:
Alan 2026-03-25 18:24:39 +11:00
commit 0e58d36bcb
4 changed files with 115 additions and 0 deletions

91
train.py Normal file
View file

@ -0,0 +1,91 @@
# marimo notebook: each @app.cell function below is a reactive cell whose
# parameters are the variables it consumes and whose return tuple is the
# variables it exports to other cells.
import marimo

__generated_with = "0.21.1"  # marimo version that generated this file
app = marimo.App(width="full")
@app.cell
def _():
    # Dependency cell: brings in the fine-tuning stack.
    # NOTE(review): unsloth is conventionally imported before torch/trl so its
    # runtime patches apply first — confirm before reordering these imports.
    from unsloth import FastLanguageModel
    import torch  # imported for its side effects; not exported to other cells
    from datasets import load_dataset
    from trl import SFTTrainer, SFTConfig
    return FastLanguageModel, SFTConfig, SFTTrainer, load_dataset
@app.cell
def _(load_dataset):
    # Data cell: fixes the sequence-length budget and loads the training set.
    max_seq_length = 2048  # start small; scale up after it works
    # Example dataset (replace with yours). Needs a "text" column.
    url = "https://huggingface.co/datasets/laion/OIG/resolve/main/unified_chip2.jsonl"
    dataset = load_dataset("json", data_files={"train": url}, split="train")
    return dataset, max_seq_length
@app.cell
def _(dataset):
    # Preview cell: render the first rows as a pandas table for inspection.
    # Convert only the rows being previewed — the original
    # `dataset.to_pandas().head(50)` materialized the ENTIRE dataset into
    # pandas first. `min(...)` also keeps this safe for datasets < 50 rows.
    dataset.select(range(min(50, len(dataset)))).to_pandas()
    return
@app.cell
def _(FastLanguageModel, max_seq_length):
    # Model cell: load a pretrained Qwen checkpoint in 16-bit precision and
    # wrap it with LoRA adapters so only the low-rank weights are trained.
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name = "unsloth/Qwen3.5-0.8B",
        max_seq_length = max_seq_length,
        load_in_4bit = False,  # no 4-bit quantization; weights stay 16-bit
        # NOTE(review): verify `load_in_16bit` is an accepted kwarg in the
        # pinned unsloth version — older releases only expose load_in_4bit/8bit.
        load_in_16bit = True, # bf16/16-bit LoRA
        full_finetuning = False,  # LoRA only; base weights remain frozen
    )

    # Attach LoRA adapters to the attention and MLP projection layers.
    model = FastLanguageModel.get_peft_model(
        model,
        r = 16,  # LoRA rank (adapter bottleneck width)
        target_modules = [
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ],
        lora_alpha = 16,  # LoRA scaling factor
        lora_dropout = 0,  # no dropout on adapter activations
        bias = "none",  # train no bias terms
        # "unsloth" checkpointing is intended for very long context + lower VRAM
        use_gradient_checkpointing = "unsloth",
        random_state = 3407,  # seed for adapter weight init
        max_seq_length = max_seq_length,
    )
    return model, tokenizer
@app.cell
def _(SFTConfig, SFTTrainer, dataset, max_seq_length, model, tokenizer):
    # Trainer cell: configure supervised fine-tuning (SFT) over the dataset.
    trainer = SFTTrainer(
        model = model,
        train_dataset = dataset,
        # NOTE(review): newer TRL versions renamed `tokenizer=` to
        # `processing_class=` — confirm against the pinned trl version.
        tokenizer = tokenizer,
        args = SFTConfig(
            max_seq_length = max_seq_length,
            per_device_train_batch_size = 1,
            gradient_accumulation_steps = 4,  # effective batch size = 1 * 4
            warmup_steps = 10,
            max_steps = 100,  # short smoke run; raise for a real training run
            logging_steps = 1,  # log loss every optimizer step
            output_dir = "outputs_qwen35",
            optim = "adamw_8bit",  # 8-bit AdamW to reduce optimizer VRAM
            seed = 3407,
            dataset_num_proc = 1,  # single-process tokenization/preprocessing
        ),
    )
    return (trainer,)
@app.cell
def _(trainer):
    # Training cell: run the fine-tuning loop (blocks until max_steps is hit).
    trainer.train()
    return
if __name__ == "__main__":
    # Allow running the notebook as a plain script: `python train.py`.
    app.run()