summary refs log tree commit diff
path: root/pkgs/development/python-modules/shap/default.nix
blob: 96c9c1f952acd453010ae1704a772a9b818be195 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
# Package definition for the "shap" Python module, built from the upstream
# GitHub tag. The test suite runs with the common network entry points mocked
# out (see preCheck); the package is currently marked broken (see meta.broken).
{ lib
, buildPythonPackage
, fetchFromGitHub
, writeText
, isPy27
, pytestCheckHook
, pytest-mpl
, numpy
, scipy
, scikit-learn
, pandas
, transformers
, opencv4
, lightgbm
, catboost
, pyspark
, sentencepiece
, tqdm
, slicer
, numba
, matplotlib
, nose
, lime
, cloudpickle
, ipython
}:

buildPythonPackage rec {
  pname = "shap";
  version = "0.41.0";
  disabled = isPy27;

  src = fetchFromGitHub {
    owner = "slundberg";
    repo = pname;
    rev = "refs/tags/v${version}";
    hash = "sha256-rYVWQ3VRvIObSQPwDRsxhTOGOKNkYkLtiHzVwoB3iJ0=";
  };

  propagatedBuildInputs = [
    numpy
    scipy
    scikit-learn
    pandas
    tqdm
    slicer
    numba
    cloudpickle
  ];

  # Optional extras; exposed through passthru only, not propagated.
  passthru.optional-dependencies = {
    plots = [ matplotlib ipython ];
    others = [ lime ];
  };

  preCheck = let
    # This pytest hook mocks out the common network entry points. Tests that
    # try to access the network raise NetworkAccessDeniedError, which is caught
    # in pytest_runtest_makereport and reported as skipped/xfailed instead of
    # failed.
    conftestSkipNetworkErrors = writeText "conftest.py" ''
      from _pytest.runner import pytest_runtest_makereport as orig_pytest_runtest_makereport
      # Import the submodule explicitly: a bare "import urllib" does not
      # guarantee that urllib.request has been loaded.
      import urllib.request
      import requests

      class NetworkAccessDeniedError(RuntimeError): pass
      def deny_network_access(*a, **kw):
          raise NetworkAccessDeniedError

      requests.head = deny_network_access
      requests.get  = deny_network_access
      urllib.request.urlopen = deny_network_access
      urllib.request.Request = deny_network_access

      def pytest_runtest_makereport(item, call):
          tr = orig_pytest_runtest_makereport(item, call)
          if call.excinfo is not None and call.excinfo.type is NetworkAccessDeniedError:
              tr.outcome = 'skipped'
              tr.wasxfail = "reason: Requires network access."
          return tr
    '';
  in ''
    export HOME=$TMPDIR
    # when importing the local copy the extension is not found
    rm -r shap

    # coverage testing is a waste considering how much we have to skip
    substituteInPlace pytest.ini \
      --replace "--cov=shap --cov-report=term-missing" ""

    # Add pytest hook skipping tests that access network.
    # These tests are marked as "Expected fail" (xfail)
    cat ${conftestSkipNetworkErrors} >> tests/conftest.py
  '';
  nativeCheckInputs = [
    pytestCheckHook
    pytest-mpl
    matplotlib
    nose
    ipython
    # optional dependencies, which only serve to enable more tests:
    opencv4
    #pytorch # we already skip all its tests due to slowness, adding it does nothing
    transformers
    #xgboost # numerically unstable? xgboost tests randomly fail depending on nixpkgs revision
    lightgbm
    catboost
    pyspark
    sentencepiece
  ];
  disabledTestPaths = [
    # takes forever without GPU acceleration
    "tests/explainers/test_deep.py"
    "tests/explainers/test_gradient.py"
    # requires GPU. We skip here instead of having pytest repeatedly check for GPU
    "tests/explainers/test_gpu_tree.py"
    # The resulting plots look sane, but do not match pixel-perfectly with the baseline.
    # Likely due to a matplotlib version mismatch, different backend, or due to missing fonts.
    "tests/plots/test_summary.py" # FIXME: enable
    # 100% of the tests in these paths require network
    "tests/explainers/test_explainer.py"
    "tests/explainers/test_exact.py"
    "tests/explainers/test_partition.py"
    "tests/maskers/test_fixed_composite.py"
    "tests/maskers/test_text.py"
    "tests/models/test_teacher_forcing_logits.py"
    "tests/models/test_text_generation.py"
  ];
  disabledTests = [
    # unstable. An xgboost-enabled test. possibly related: https://github.com/slundberg/shap/issues/2480
    "test_provided_background_tree_path_dependent"
  ];

  #pytestFlagsArray = ["-x" "-W" "ignore"]; # uncomment this to debug

  # Smoke-test that the installed package and its main subpackages import.
  pythonImportsCheck = [
    "shap"
    "shap.explainers"
    "shap.explainers.other"
    "shap.plots"
    "shap.plots.colors"
    "shap.benchmark"
    "shap.maskers"
    "shap.utils"
    "shap.actions"
    "shap.models"
  ];

  meta = with lib; {
    description = "A unified approach to explain the output of any machine learning model";
    homepage = "https://github.com/slundberg/shap";
    changelog = "https://github.com/slundberg/shap/releases/tag/v${version}";
    license = licenses.mit;
    maintainers = with maintainers; [ evax ];
    platforms = platforms.unix;
    # No support for scikit-learn > 1.2
    # https://github.com/slundberg/shap/issues/2866
    broken = true;
  };
}