init
This commit is contained in:
commit
0e58d36bcb
4 changed files with 115 additions and 0 deletions
1
.envrc
Normal file
1
.envrc
Normal file
|
|
@ -0,0 +1 @@
|
||||||
|
use nix
|
||||||
2
.gitignore
vendored
Normal file
2
.gitignore
vendored
Normal file
|
|
@ -0,0 +1,2 @@
|
||||||
|
unsloth_compiled_cache
|
||||||
|
outputs_qwen35
|
||||||
21
shell.nix
Normal file
21
shell.nix
Normal file
|
|
@ -0,0 +1,21 @@
|
||||||
|
# Development shell for the fine-tuning notebook.
#
# Provides ruff and python3 from nixpkgs, then bootstraps a project-local
# Python virtualenv in ./.venv with the pip-installed training stack.
{nixpkgs ? import <nixpkgs> {}}:
nixpkgs.mkShell {
  nativeBuildInputs = with nixpkgs; [
    ruff
    python3
  ];

  # Expose libstdc++, zstd and zlib on the dynamic loader path.
  # NOTE(review): presumably needed so pip-installed binary wheels can find
  # these shared libraries outside the Nix store -- confirm.
  LD_LIBRARY_PATH = "${nixpkgs.stdenv.cc.cc.lib}/lib:${nixpkgs.zstd.out}/lib:${nixpkgs.zlib.out}/lib";

  shellHook = ''
    # Bootstrap is gated on a marker file rather than on the mere existence
    # of .venv: if any pip step fails or is interrupted, the marker is never
    # written and the next shell entry retries instead of silently reusing a
    # half-installed environment.
    if [[ ! -f ".venv/.bootstrapped" ]]; then
      python -m venv .venv
      source .venv/bin/activate
      pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm7.2/ --upgrade --force-reinstall \
        && pip install "unsloth[amd] @ git+https://github.com/unslothai/unsloth" \
        && pip install "unsloth-zoo[main] @ git+https://github.com/unslothai/unsloth-zoo" \
        && pip install marimo ipython \
        && touch .venv/.bootstrapped
    else
      source .venv/bin/activate
    fi
  '';
}
|
||||||
91
train.py
Normal file
91
train.py
Normal file
|
|
@ -0,0 +1,91 @@
|
||||||
|
import marimo
|
||||||
|
|
||||||
|
# Version string recorded in the notebook file (presumably the marimo
# release that generated it).
__generated_with = "0.21.1"

# Notebook application object; each @app.cell function below registers on it.
app = marimo.App(width="full")
|
||||||
|
|
||||||
|
|
||||||
|
@app.cell
def _():
    # Import cell: exposes the training entry points to the other cells.
    # unsloth stays first, as in the original ordering.
    # NOTE(review): presumably so it can patch trl before trl loads -- confirm.
    from unsloth import FastLanguageModel
    import torch
    from datasets import load_dataset
    from trl import SFTConfig, SFTTrainer

    return (
        FastLanguageModel,
        SFTConfig,
        SFTTrainer,
        load_dataset,
    )
|
||||||
|
|
||||||
|
|
||||||
|
@app.cell
def _(load_dataset):
    # Sequence-length cap handed to the model and trainer cells.
    # Start small; scale up once the pipeline works end to end.
    max_seq_length = 2048

    # Example dataset (replace with yours). Needs a "text" column.
    url = "https://huggingface.co/datasets/laion/OIG/resolve/main/unified_chip2.jsonl"
    _data_files = {"train": url}
    dataset = load_dataset("json", data_files=_data_files, split="train")
    return dataset, max_seq_length
|
||||||
|
|
||||||
|
|
||||||
|
@app.cell
def _(dataset):
    # Preview the first rows of the training data as a pandas DataFrame.
    # Only the previewed slice is converted: calling to_pandas() on the whole
    # dataset would materialize every row just to display 50 of them.
    # min() keeps select() in range when the dataset has fewer than 50 rows.
    _preview_rows = min(50, len(dataset))
    # Left as the final expression so the notebook renders it as cell output.
    dataset.select(range(_preview_rows)).to_pandas()
    return
|
||||||
|
|
||||||
|
|
||||||
|
@app.cell
def _(FastLanguageModel, max_seq_length):
    # Load the base checkpoint (unsloth/Qwen3.5-0.8B) with 4-bit loading off
    # and 16-bit loading on, then wrap it with LoRA adapters.
    _base_model, tokenizer = FastLanguageModel.from_pretrained(
        model_name="unsloth/Qwen3.5-0.8B",
        max_seq_length=max_seq_length,
        # Author's note: MoE QLoRA not recommended, dense 27B is fine.
        load_in_4bit=False,
        # bf16/16-bit LoRA
        load_in_16bit=True,
        full_finetuning=False,
    )

    # Projection layers that receive LoRA adapters.
    _lora_targets = [
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ]

    model = FastLanguageModel.get_peft_model(
        _base_model,
        r=16,
        target_modules=_lora_targets,
        lora_alpha=16,
        lora_dropout=0,
        bias="none",
        # "unsloth" checkpointing is intended for very long context + lower VRAM
        use_gradient_checkpointing="unsloth",
        random_state=3407,
        max_seq_length=max_seq_length,
    )
    return model, tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
@app.cell
def _(SFTConfig, SFTTrainer, dataset, max_seq_length, model, tokenizer):
    # Training hyperparameters for a short run: 100 optimizer steps, with
    # batch size 1 and 4 gradient-accumulation steps per device.
    _sft_args = SFTConfig(
        max_seq_length=max_seq_length,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        warmup_steps=10,
        max_steps=100,
        logging_steps=1,
        output_dir="outputs_qwen35",
        optim="adamw_8bit",
        seed=3407,
        dataset_num_proc=1,
    )

    # Supervised fine-tuning trainer over the LoRA-wrapped model and dataset.
    trainer = SFTTrainer(
        model=model,
        train_dataset=dataset,
        tokenizer=tokenizer,
        args=_sft_args,
    )
    return (trainer,)
|
||||||
|
|
||||||
|
|
||||||
|
@app.cell
def _(trainer):
    # Run the fine-tuning loop configured on the trainer. The call is left as
    # the cell's final expression so its result stays the cell's output.
    trainer.train()
    return
|
||||||
|
|
||||||
|
|
||||||
|
# Script entry point: run the notebook app when this file is executed directly.
if __name__ == "__main__":
    app.run()
|
||||||
Loading…
Add table
Add a link
Reference in a new issue