about summary refs log tree commit diff
path: root/pkgs/development/python-modules/pytorch-metric-learning/default.nix
diff options
context:
space:
mode:
Diffstat (limited to 'pkgs/development/python-modules/pytorch-metric-learning/default.nix')
-rw-r--r--pkgs/development/python-modules/pytorch-metric-learning/default.nix86
1 files changed, 60 insertions, 26 deletions
diff --git a/pkgs/development/python-modules/pytorch-metric-learning/default.nix b/pkgs/development/python-modules/pytorch-metric-learning/default.nix
index ec57a02f0acd4..049064754937b 100644
--- a/pkgs/development/python-modules/pytorch-metric-learning/default.nix
+++ b/pkgs/development/python-modules/pytorch-metric-learning/default.nix
@@ -4,19 +4,31 @@
   buildPythonPackage,
   fetchFromGitHub,
   isPy27,
+  config,
+
+  # build-system
+  setuptools,
+
+  # dependencies
   numpy,
   scikit-learn,
-  pytestCheckHook,
   torch,
-  torchvision,
   tqdm,
+
+  # optional-dependencies
   faiss,
+  tensorboard,
+
+  # tests
+  cudaSupport ? config.cudaSupport,
+  pytestCheckHook,
+  torchvision
 }:
 
 buildPythonPackage rec {
   pname = "pytorch-metric-learning";
   version = "2.5.0";
-  format = "setuptools";
+  pyproject = true;
 
   disabled = isPy27;
 
@@ -27,14 +39,30 @@ buildPythonPackage rec {
     hash = "sha256-1y7VCnzgwFOMeMloVdYyszNhf/zZlBJUjuF4qgA5c0A=";
   };
 
-  propagatedBuildInputs = [
+  build-system = [
+    setuptools
+  ];
+
+  dependencies = [
     numpy
     torch
     scikit-learn
-    torchvision
     tqdm
   ];
 
+  optional-dependencies = {
+    with-hooks = [
+      # TODO: record-keeper
+      faiss
+      tensorboard
+    ];
+    with-hooks-cpu = [
+      # TODO: record-keeper
+      faiss
+      tensorboard
+    ];
+  };
+
   preCheck = ''
     export HOME=$TMP
     export TEST_DEVICE=cpu
@@ -43,29 +71,35 @@ buildPythonPackage rec {
 
   # package only requires `unittest`, but use `pytest` to exclude tests
   nativeCheckInputs = [
-    faiss
     pytestCheckHook
-  ];
+    torchvision
+  ] ++ lib.flatten (lib.attrValues optional-dependencies);
 
-  disabledTests =
-    [
-      # TypeError: setup() missing 1 required positional argument: 'world_size'
-      "TestDistributedLossWrapper"
-      # require network access:
-      "TestInference"
-      "test_get_nearest_neighbors"
-      "test_tuplestoweights_sampler"
-      "test_untrained_indexer"
-      "test_metric_loss_only"
-      "test_pca"
-      # flaky
-      "test_distributed_classifier_loss_and_miner"
-    ]
-    ++ lib.optionals (stdenv.isLinux && stdenv.isAarch64) [
-      # RuntimeError: DataLoader worker (pid(s) <...>) exited unexpectedly
-      "test_global_embedding_space_tester"
-      "test_with_same_parent_label_tester"
-    ];
+  disabledTests = [
+    # network access
+    "test_tuplestoweights_sampler"
+    "test_metric_loss_only"
+    "test_add_to_indexer"
+    "test_get_nearest_neighbors"
+    "test_list_of_text"
+    "test_untrained_indexer"
+  ] ++ lib.optionals stdenv.isDarwin [
+    # AttributeError: module 'torch.distributed' has no attribute 'init_process_group'
+    "test_single_proc"
+  ] ++ lib.optionals cudaSupport [
+    # crashes with SIGBART
+    "test_accuracy_calculator_and_faiss_with_torch_and_numpy"
+    "test_accuracy_calculator_large_k"
+    "test_custom_knn"
+    "test_global_embedding_space_tester"
+    "test_global_two_stream_embedding_space_tester"
+    "test_index_type"
+    "test_k_warning"
+    "test_many_tied_distances"
+    "test_query_within_reference"
+    "test_tied_distances"
+    "test_with_same_parent_label_tester"
+  ];
 
   meta = {
     description = "Metric learning library for PyTorch";