{
  description = "Minimal example training a Qwen3.5 0.8B lora on AMD RYZEN AI MAX+ 395 APU";

  inputs = {
    nixpkgs.url = "github:nixos/nixpkgs?ref=nixos-unstable";
    flake-utils.url = "github:numtide/flake-utils";
  };

  outputs = {
    self,
    nixpkgs,
    flake-utils,
  }:
  # NOTE(review): ROCm is Linux-only, but eachDefaultSystem also generates
  # darwin (and aarch64-linux) outputs; those shells presumably fail to
  # evaluate. Consider `flake-utils.lib.eachSystem [ "x86_64-linux" ]`
  # if no other system is meant to be supported.
    flake-utils.lib.eachDefaultSystem (
      system: let
        # nixpkgs instance for this system. Unfree packages are allowed and
        # `rocmSupport` is set so packages that honour it (e.g. torch) are
        # built against ROCm.
        pkgs = import nixpkgs {
          inherit system;
          config = {
            allowUnfree = true;
            rocmSupport = true;
          };
        };

        # Python 3.13 with the package overrides this project needs.
        # The fixed-point arguments are named pyFinal/pyPrev so they do not
        # shadow the flake's own `self`.
        python = pkgs.python313.override {
          packageOverrides = pyFinal: pyPrev: {
            # Pin `datasets` to the 4.3.0 GitHub tag.
            # NOTE(review): the reason for this pin is not recorded here —
            # presumably required by unsloth/trl; confirm before bumping.
            datasets = pyPrev.datasets.overridePythonAttrs (oldAttrs: rec {
              version = "4.3.0";
              src = pkgs.fetchFromGitHub {
                owner = "huggingface";
                repo = "datasets";
                tag = version;
                hash = "sha256-3rDSHAMwoe9CkRLs3PDXSw2ONUrUWyBSZFpzk2C1A3A=";
              };
            });

            # Skip the build-time import check for unsloth-zoo.
            # NOTE(review): presumably because importing it needs a GPU /
            # runtime environment unavailable in the build sandbox — confirm.
            unsloth-zoo = pyPrev.unsloth-zoo.overridePythonAttrs (oldAttrs: {
              pythonImportsCheck = [];
            });

            unsloth = pyPrev.unsloth.overridePythonAttrs (oldAttrs: {
              # Extra runtime dependencies appended to whatever the upstream
              # derivation declares (tolerates the attribute being absent).
              # They come from the final package set so later overrides apply.
              dependencies =
                (oldAttrs.dependencies or [])
                ++ [
                  pyFinal.pydantic
                  pyFinal.nest-asyncio
                ];

              # Allow the installed trl version to differ from unsloth's pin.
              # If upstream already relaxes everything (`true`), keep that;
              # otherwise append "trl" to the existing list (or an empty one).
              pythonRelaxDeps = let
                upstream = oldAttrs.pythonRelaxDeps or [];
              in
                if builtins.isList upstream
                then upstream ++ ["trl"]
                else upstream;

              # Append to (rather than replace) any upstream postPatch.
              # `|| true` keeps the build going if pyproject.toml is absent.
              postPatch =
                (oldAttrs.postPatch or "")
                + ''
                  # Relax setuptools version constraint in pyproject.toml
                  sed -i 's/setuptools==80\.9\.0/setuptools>=80.9/g' pyproject.toml || true

                  # Relax setuptools-scm version constraint in pyproject.toml
                  sed -i 's/setuptools-scm==9\.2\.0/setuptools-scm>=9.2/g' pyproject.toml || true
                '';
            });
          };
        };

        # Interpreter environment exposed in the dev shell.
        pythonEnv = python.withPackages (ps:
          with ps; [
            ## Base ML/DS libs
            torch
            torchvision
            torchaudio
            transformers
            accelerate

            ## Unsloth
            unsloth

            ## QOL
            marimo
            ipython
          ]);

        # Development shell containing only the Python environment above.
        shell = pkgs.mkShell {
          packages = [
            pythonEnv
          ];
        };
      in {
        # Current flake output name, used by `nix develop`.
        devShells.default = shell;

        # Deprecated legacy output name, kept so existing consumers of
        # `devShell.<system>` keep working; safe to drop once none remain.
        devShell = shell;
      }
    );
}