diff options
Diffstat (limited to 'pkgs/development/python-modules/torchrl/default.nix')
-rw-r--r-- | pkgs/development/python-modules/torchrl/default.nix | 127 |
1 files changed, 77 insertions, 50 deletions
diff --git a/pkgs/development/python-modules/torchrl/default.nix b/pkgs/development/python-modules/torchrl/default.nix index 591e59302ea6a..e075696881104 100644 --- a/pkgs/development/python-modules/torchrl/default.nix +++ b/pkgs/development/python-modules/torchrl/default.nix @@ -1,37 +1,48 @@ -{ lib -, buildPythonPackage -, pythonOlder -, fetchFromGitHub -, ninja -, setuptools -, wheel -, which -, cloudpickle -, numpy -, torch -, ale-py -, gym -, pygame -, gymnasium -, mujoco -, moviepy -, git -, hydra-core -, tensorboard -, tqdm -, wandb -, packaging -, tensordict -, imageio -, pytest-rerunfailures -, pytestCheckHook -, pyyaml -, scipy +{ + lib, + buildPythonPackage, + pythonOlder, + fetchFromGitHub, + ninja, + setuptools, + wheel, + which, + cloudpickle, + numpy, + torch, + ale-py, + gym, + pygame, + torchsnapshot, + gymnasium, + mujoco, + h5py, + huggingface-hub, + minari, + pandas, + pillow, + requests, + scikit-learn, + torchvision, + tqdm, + moviepy, + git, + hydra-core, + tensorboard, + wandb, + packaging, + tensordict, + imageio, + pytest-rerunfailures, + pytestCheckHook, + pyyaml, + scipy, + stdenv, }: buildPythonPackage rec { pname = "torchrl"; - version = "0.3.1"; + version = "0.4.0"; pyproject = true; disabled = pythonOlder "3.8"; @@ -40,17 +51,17 @@ buildPythonPackage rec { owner = "pytorch"; repo = "rl"; rev = "refs/tags/v${version}"; - hash = "sha256-lETW996IKPUGgZpe+cyzrXvVmDSwj5G4XFreFmGxReQ="; + hash = "sha256-8wSyyErqveP9zZS/UGvWVBYyylu9BuA447GEjXIzBIk="; }; - nativeBuildInputs = [ + build-system = [ ninja setuptools wheel which ]; - propagatedBuildInputs = [ + dependencies = [ cloudpickle numpy packaging @@ -64,13 +75,23 @@ buildPythonPackage rec { gym pygame ]; + checkpointing = [ torchsnapshot ]; gym-continuous = [ gymnasium mujoco ]; - rendering = [ - moviepy + offline-data = [ + h5py + huggingface-hub + minari + pandas + pillow + requests + scikit-learn + torchvision + tqdm ]; + rendering = [ moviepy ]; utils = [ git hydra-core @@ -85,9 +106,7 @@ buildPythonPackage rec { export D4RL_DATASET_DIR=$(mktemp -d) ''; - pythonImportsCheck = [ - "torchrl" - ]; + pythonImportsCheck = [ "torchrl" ]; # We have to delete the source because otherwise it is used instead of the installed package. preCheck = '' @@ -96,17 +115,19 @@ buildPythonPackage rec { export XDG_RUNTIME_DIR=$(mktemp -d) ''; - nativeCheckInputs = [ - gymnasium - imageio - pytest-rerunfailures - pytestCheckHook - pyyaml - scipy - ] - ++ passthru.optional-dependencies.atari - ++ passthru.optional-dependencies.gym-continuous - ++ passthru.optional-dependencies.rendering; + nativeCheckInputs = + [ + gymnasium + imageio + pytest-rerunfailures + pytestCheckHook + pyyaml + scipy + torchvision + ] + ++ passthru.optional-dependencies.atari + ++ passthru.optional-dependencies.gym-continuous + ++ passthru.optional-dependencies.rendering; disabledTests = [ # mujoco.FatalError: an OpenGL platform library has not been loaded into this process, this most likely means that a valid OpenGL context has not been created before mjr_makeContext was called @@ -129,13 +150,19 @@ buildPythonPackage rec { "test_trans_parallel_env_check" "test_trans_serial_env_check" "test_transform_env" + + # undeterministic + "test_distributed_collector_updatepolicy" + "test_timeit" ]; meta = with lib; { - description = "A modular, primitive-first, python-first PyTorch library for Reinforcement Learning"; + description = "Modular, primitive-first, python-first PyTorch library for Reinforcement Learning"; homepage = "https://github.com/pytorch/rl"; changelog = "https://github.com/pytorch/rl/releases/tag/v${version}"; license = licenses.mit; maintainers = with maintainers; [ GaetanLepage ]; + # ~3k tests fail with: RuntimeError: internal error + broken = stdenv.isLinux && stdenv.isAarch64; }; } |