about summary refs log tree commit diff
path: root/pkgs/development/python-modules/torchrl/default.nix
diff options
context:
space:
mode:
Diffstat (limited to 'pkgs/development/python-modules/torchrl/default.nix')
-rw-r--r--pkgs/development/python-modules/torchrl/default.nix121
1 files changed, 72 insertions, 49 deletions
diff --git a/pkgs/development/python-modules/torchrl/default.nix b/pkgs/development/python-modules/torchrl/default.nix
index 591e59302ea6a..6951192f52c48 100644
--- a/pkgs/development/python-modules/torchrl/default.nix
+++ b/pkgs/development/python-modules/torchrl/default.nix
@@ -1,37 +1,48 @@
-{ lib
-, buildPythonPackage
-, pythonOlder
-, fetchFromGitHub
-, ninja
-, setuptools
-, wheel
-, which
-, cloudpickle
-, numpy
-, torch
-, ale-py
-, gym
-, pygame
-, gymnasium
-, mujoco
-, moviepy
-, git
-, hydra-core
-, tensorboard
-, tqdm
-, wandb
-, packaging
-, tensordict
-, imageio
-, pytest-rerunfailures
-, pytestCheckHook
-, pyyaml
-, scipy
+{
+  lib,
+  buildPythonPackage,
+  pythonOlder,
+  fetchFromGitHub,
+  ninja,
+  setuptools,
+  wheel,
+  which,
+  cloudpickle,
+  numpy,
+  torch,
+  ale-py,
+  gym,
+  pygame,
+  torchsnapshot,
+  gymnasium,
+  mujoco,
+  h5py,
+  huggingface-hub,
+  minari,
+  pandas,
+  pillow,
+  requests,
+  scikit-learn,
+  torchvision,
+  tqdm,
+  moviepy,
+  git,
+  hydra-core,
+  tensorboard,
+  wandb,
+  packaging,
+  tensordict,
+  imageio,
+  pytest-rerunfailures,
+  pytestCheckHook,
+  pyyaml,
+  scipy,
+  stdenv,
 }:
 
 buildPythonPackage rec {
   pname = "torchrl";
-  version = "0.3.1";
+  version = "0.4.0";
   pyproject = true;
 
   disabled = pythonOlder "3.8";
@@ -40,17 +51,17 @@ buildPythonPackage rec {
     owner = "pytorch";
     repo = "rl";
     rev = "refs/tags/v${version}";
-    hash = "sha256-lETW996IKPUGgZpe+cyzrXvVmDSwj5G4XFreFmGxReQ=";
+    hash = "sha256-8wSyyErqveP9zZS/UGvWVBYyylu9BuA447GEjXIzBIk=";
   };
 
-  nativeBuildInputs = [
+  build-system = [
     ninja
     setuptools
     wheel
     which
   ];
 
-  propagatedBuildInputs = [
+  dependencies = [
     cloudpickle
     numpy
     packaging
@@ -64,13 +75,23 @@ buildPythonPackage rec {
       gym
       pygame
     ];
+    checkpointing = [ torchsnapshot ];
     gym-continuous = [
       gymnasium
       mujoco
     ];
-    rendering = [
-      moviepy
+    offline-data = [
+      h5py
+      huggingface-hub
+      minari
+      pandas
+      pillow
+      requests
+      scikit-learn
+      torchvision
+      tqdm
     ];
+    rendering = [ moviepy ];
     utils = [
       git
       hydra-core
@@ -85,9 +106,7 @@ buildPythonPackage rec {
     export D4RL_DATASET_DIR=$(mktemp -d)
   '';
 
-  pythonImportsCheck = [
-    "torchrl"
-  ];
+  pythonImportsCheck = [ "torchrl" ];
 
   # We have to delete the source because otherwise it is used instead of the installed package.
   preCheck = ''
@@ -96,17 +115,19 @@ buildPythonPackage rec {
     export XDG_RUNTIME_DIR=$(mktemp -d)
   '';
 
-  nativeCheckInputs = [
-    gymnasium
-    imageio
-    pytest-rerunfailures
-    pytestCheckHook
-    pyyaml
-    scipy
-  ]
-  ++ passthru.optional-dependencies.atari
-  ++ passthru.optional-dependencies.gym-continuous
-  ++ passthru.optional-dependencies.rendering;
+  nativeCheckInputs =
+    [
+      gymnasium
+      imageio
+      pytest-rerunfailures
+      pytestCheckHook
+      pyyaml
+      scipy
+      torchvision
+    ]
+    ++ passthru.optional-dependencies.atari
+    ++ passthru.optional-dependencies.gym-continuous
+    ++ passthru.optional-dependencies.rendering;
 
   disabledTests = [
     # mujoco.FatalError: an OpenGL platform library has not been loaded into this process, this most likely means that a valid OpenGL context has not been created before mjr_makeContext was called
@@ -137,5 +158,7 @@ buildPythonPackage rec {
     changelog = "https://github.com/pytorch/rl/releases/tag/v${version}";
     license = licenses.mit;
     maintainers = with maintainers; [ GaetanLepage ];
+    # ~3k tests fail with: RuntimeError: internal error
+    broken = stdenv.isLinux && stdenv.isAarch64;
   };
 }