diff options
Diffstat (limited to 'pkgs/development/python-modules/torchvision/default.nix')
-rw-r--r-- | pkgs/development/python-modules/torchvision/default.nix | 24 |
1 files changed, 21 insertions, 3 deletions
diff --git a/pkgs/development/python-modules/torchvision/default.nix b/pkgs/development/python-modules/torchvision/default.nix index a42c517ede967..fc9905881cb6a 100644 --- a/pkgs/development/python-modules/torchvision/default.nix +++ b/pkgs/development/python-modules/torchvision/default.nix @@ -1,4 +1,5 @@ { lib +, symlinkJoin , buildPythonPackage , fetchFromGitHub , ninja @@ -10,9 +11,18 @@ , pillow , pytorch , pytest +, cudatoolkit +, cudnn +, cudaSupport ? pytorch.cudaSupport or false # by default uses the value from pytorch }: -buildPythonPackage rec { +let + cudatoolkit_joined = symlinkJoin { + name = "${cudatoolkit.name}-unsplit"; + paths = [ cudatoolkit.out cudatoolkit.lib ]; + }; + cudaArchStr = lib.optionalString cudaSupport lib.strings.concatStringsSep ";" pytorch.cudaArchList; +in buildPythonPackage rec { pname = "torchvision"; version = "0.10.0"; @@ -23,15 +33,22 @@ buildPythonPackage rec { sha256 = "13j04ij0jmi58nhav1p69xrm8dg7jisg23268i3n6lnms37n02kc"; }; - nativeBuildInputs = [ libpng ninja which ]; + nativeBuildInputs = [ libpng ninja which ] + ++ lib.optionals cudaSupport [ cudatoolkit_joined ]; TORCHVISION_INCLUDE = "${libjpeg_turbo.dev}/include/"; TORCHVISION_LIBRARY = "${libjpeg_turbo}/lib/"; - buildInputs = [ libjpeg_turbo libpng ]; + buildInputs = [ libjpeg_turbo libpng ] + ++ lib.optionals cudaSupport [ cudnn ]; propagatedBuildInputs = [ numpy pillow pytorch scipy ]; + preBuild = lib.optionalString cudaSupport '' + export TORCH_CUDA_ARCH_LIST="${cudaArchStr}" + export FORCE_CUDA=1 + ''; + # tries to download many datasets for tests doCheck = false; @@ -45,6 +62,7 @@ buildPythonPackage rec { description = "PyTorch vision library"; homepage = "https://pytorch.org/"; license = licenses.bsd3; + platforms = with platforms; linux ++ lib.optionals (!cudaSupport) darwin; maintainers = with maintainers; [ ericsagnes ]; }; } |