commit 0e58d36bcb7967ab3d6eefdc29fbbaabb33b7cbe
Author: Alan
Date:   Wed Mar 25 18:24:39 2026 +1100

    init

diff --git a/.envrc b/.envrc
new file mode 100644
index 0000000..1d953f4
--- /dev/null
+++ b/.envrc
@@ -0,0 +1 @@
+use nix
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..90eafaf
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+unsloth_compiled_cache
+outputs_qwen35
diff --git a/shell.nix b/shell.nix
new file mode 100644
index 0000000..c95210a
--- /dev/null
+++ b/shell.nix
@@ -0,0 +1,21 @@
+{nixpkgs ? import <nixpkgs> {}}:
+nixpkgs.mkShell {
+  nativeBuildInputs = with nixpkgs; [
+    ruff
+    python3
+  ];
+
+  LD_LIBRARY_PATH = "${nixpkgs.stdenv.cc.cc.lib}/lib:${nixpkgs.zstd.out}/lib:${nixpkgs.zlib.out}/lib";
+  shellHook = ''
+    if [[ ! -d ".venv" ]]; then
+      python -m venv .venv
+      source .venv/bin/activate
+      pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm7.2/ --upgrade --force-reinstall
+      pip install "unsloth[amd] @ git+https://github.com/unslothai/unsloth"
+      pip install "unsloth-zoo[main] @ git+https://github.com/unslothai/unsloth-zoo"
+      pip install marimo ipython
+    else
+      source .venv/bin/activate
+    fi
+  '';
+}
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..8b1aad1
--- /dev/null
+++ b/train.py
@@ -0,0 +1,91 @@
+import marimo
+
+__generated_with = "0.21.1"
+app = marimo.App(width="full")
+
+
+@app.cell
+def _():
+    from unsloth import FastLanguageModel
+    import torch
+    from datasets import load_dataset
+    from trl import SFTTrainer, SFTConfig
+
+    return FastLanguageModel, SFTConfig, SFTTrainer, load_dataset
+
+
+@app.cell
+def _(load_dataset):
+    max_seq_length = 2048  # start small; scale up after it works
+
+    # Example dataset (replace with yours). Needs a "text" column.
+    url = "https://huggingface.co/datasets/laion/OIG/resolve/main/unified_chip2.jsonl"
+    dataset = load_dataset("json", data_files={"train": url}, split="train")
+    return dataset, max_seq_length
+
+
+@app.cell
+def _(dataset):
+    dataset.to_pandas().head(50)
+    return
+
+
+@app.cell
+def _(FastLanguageModel, max_seq_length):
+    # unsloth/Qwen3.5-0.8B
+    model, tokenizer = FastLanguageModel.from_pretrained(
+        model_name = "unsloth/Qwen3.5-0.8B",
+        max_seq_length = max_seq_length,
+        load_in_4bit = False,   # skip 4-bit QLoRA; train in 16-bit instead
+        load_in_16bit = True,   # bf16/16-bit LoRA
+        full_finetuning = False,
+    )
+
+    model = FastLanguageModel.get_peft_model(
+        model,
+        r = 16,
+        target_modules = [
+            "q_proj", "k_proj", "v_proj", "o_proj",
+            "gate_proj", "up_proj", "down_proj",
+        ],
+        lora_alpha = 16,
+        lora_dropout = 0,
+        bias = "none",
+        # "unsloth" checkpointing is intended for very long context + lower VRAM
+        use_gradient_checkpointing = "unsloth",
+        random_state = 3407,
+        max_seq_length = max_seq_length,
+    )
+    return model, tokenizer
+
+
+@app.cell
+def _(SFTConfig, SFTTrainer, dataset, max_seq_length, model, tokenizer):
+    trainer = SFTTrainer(
+        model = model,
+        train_dataset = dataset,
+        tokenizer = tokenizer,
+        args = SFTConfig(
+            max_seq_length = max_seq_length,
+            per_device_train_batch_size = 1,
+            gradient_accumulation_steps = 4,
+            warmup_steps = 10,
+            max_steps = 100,
+            logging_steps = 1,
+            output_dir = "outputs_qwen35",
+            optim = "adamw_8bit",
+            seed = 3407,
+            dataset_num_proc = 1,
+        ),
+    )
+    return (trainer,)
+
+
+@app.cell
+def _(trainer):
+    trainer.train()
+    return
+
+
+if __name__ == "__main__":
+    app.run()
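
A note on data, not part of the commit: the OIG file above already ships a "text" column, which is what SFTTrainer consumes by default. If you swap in your own chat-style data, flatten it into "text" first. A minimal sketch, assuming records shaped like {"messages": [...]} and the tokenizer returned by FastLanguageModel.from_pretrained; my_data.jsonl and the "messages" field are placeholders:

    # Hypothetical: flatten chat records into the "text" column SFTTrainer expects.
    def to_text(example):
        # tokenize=False returns the templated string rather than token ids
        return {"text": tokenizer.apply_chat_template(
            example["messages"], tokenize=False, add_generation_prompt=False,
        )}

    my_dataset = load_dataset("json", data_files="my_data.jsonl", split="train")
    my_dataset = my_dataset.map(to_text)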
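
As written, trainer.train() only leaves checkpoints under outputs_qwen35. A small follow-up sketch for persisting the adapter explicitly: on a PEFT-wrapped model, save_pretrained writes just the LoRA weights and config, not the base model; "lora_qwen35" is an arbitrary directory name:

    # Run after trainer.train() in a later notebook cell.
    model.save_pretrained("lora_qwen35")      # LoRA adapter weights + config only
    tokenizer.save_pretrained("lora_qwen35")  # keep the tokenizer alongside for reloading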
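
And a quick smoke test after training, a sketch assuming unsloth's FastLanguageModel.for_inference toggle and the model/tokenizer from the cells above; the prompt is arbitrary:

    FastLanguageModel.for_inference(model)  # switch unsloth to its faster generation path
    inputs = tokenizer(["The capital of Australia is"], return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=32)
    print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])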