# Marimo notebook: LoRA supervised fine-tuning of "unsloth/Qwen3.5-0.8B"
# with Unsloth's FastLanguageModel and TRL's SFTTrainer on a JSONL dataset.
#
# Cell order below follows the data flow: imports -> dataset -> preview ->
# model + LoRA adapters -> trainer -> training run.

import marimo

__generated_with = "0.21.1"
app = marimo.App(width="full")


@app.cell
def _():
    # Imports shared with the other cells through the return tuple.
    # NOTE(review): unsloth is kept as the first import, as in the original;
    # Unsloth's docs ask for it to precede trl -- confirm before reordering.
    from unsloth import FastLanguageModel
    from datasets import load_dataset
    from trl import SFTTrainer, SFTConfig

    return FastLanguageModel, SFTConfig, SFTTrainer, load_dataset


@app.cell
def _(load_dataset):
    # Sequence-length budget reused by the model and the trainer config.
    max_seq_length = 2048  # start small; scale up after it works

    # Example dataset (replace with yours). Needs a "text" column.
    url = "https://huggingface.co/datasets/laion/OIG/resolve/main/unified_chip2.jsonl"
    dataset = load_dataset("json", data_files={"train": url}, split="train")
    return dataset, max_seq_length


@app.cell
def _(dataset):
    # Preview the first 50 rows (or fewer if the dataset is smaller).
    # Rows are selected *before* the pandas conversion so that only the
    # previewed rows are materialized, not the whole dataset.
    dataset.select(range(min(50, len(dataset)))).to_pandas()
    return


@app.cell
def _(FastLanguageModel, max_seq_length):
    # Load the base checkpoint and its tokenizer.
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name="unsloth/Qwen3.5-0.8B",
        max_seq_length=max_seq_length,
        load_in_4bit=False,  # 4-bit (QLoRA) loading is switched off here...
        load_in_16bit=True,  # ...in favour of bf16/16-bit LoRA
        full_finetuning=False,
    )

    # Wrap the base model with LoRA adapters on the listed projection modules.
    model = FastLanguageModel.get_peft_model(
        model,
        r=16,
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
        lora_alpha=16,
        lora_dropout=0,
        bias="none",
        # "unsloth" checkpointing is intended for very long context + lower VRAM
        use_gradient_checkpointing="unsloth",
        random_state=3407,
        max_seq_length=max_seq_length,
    )
    return model, tokenizer


@app.cell
def _(SFTConfig, SFTTrainer, dataset, max_seq_length, model, tokenizer):
    # Short smoke-test run: 1 sequence per device with 4 gradient-accumulation
    # steps, capped at 100 steps and logging every step.
    trainer = SFTTrainer(
        model=model,
        train_dataset=dataset,
        tokenizer=tokenizer,
        args=SFTConfig(
            max_seq_length=max_seq_length,
            per_device_train_batch_size=1,
            gradient_accumulation_steps=4,
            warmup_steps=10,
            max_steps=100,
            logging_steps=1,
            output_dir="outputs_qwen35",
            optim="adamw_8bit",
            seed=3407,
            dataset_num_proc=1,
        ),
    )
    return (trainer,)


@app.cell
def _(trainer):
    # Kick off fine-tuning with the trainer configured above.
    trainer.train()
    return


if __name__ == "__main__":
    app.run()