Fixes compatibility with Werkzeug 2.1.0 ported over from flask-restx#423. https://github.com/python-restx/flask-restx/pull/423 diff --git a/flask_restful/reqparse.py b/flask_restful/reqparse.py index 9bb3099..5c59594 100644 --- a/flask_restful/reqparse.py +++ b/flask_restful/reqparse.py @@ -114,7 +114,10 @@ class Argument(object): :param request: The flask request object to parse arguments from """ if isinstance(self.location, six.string_types): - value = getattr(request, self.location, MultiDict()) + if self.location in {"json", "get_json"}: + value = request.get_json(silent=True) + else: + value = getattr(request, self.location, MultiDict()) if callable(value): value = value() if value is not None: @@ -122,7 +125,10 @@ class Argument(object): else: values = MultiDict() for l in self.location: - value = getattr(request, l, None) + if l in {"json", "get_json"}: + value = request.get_json(silent=True) + else: + value = getattr(request, l, None) if callable(value): value = value() if value is not None: diff --git a/tests/test_api.py b/tests/test_api.py index 15f12eb..9a9cceb 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -936,7 +936,7 @@ class APITestCase(unittest.TestCase): app = app.test_client() resp = app.get('/api') self.assertEqual(resp.status_code, 302) - self.assertEqual(resp.headers['Location'], 'http://localhost/') + self.assertEqual(resp.headers['Location'], '/') def test_json_float_marshalled(self): app = Flask(__name__) diff --git a/tests/test_reqparse.py b/tests/test_reqparse.py index 1d75e40..e5c586b 100644 --- a/tests/test_reqparse.py +++ b/tests/test_reqparse.py @@ -23,8 +23,9 @@ class ReqParseTestCase(unittest.TestCase): with app.app_context(): parser = RequestParser() parser.add_argument('foo', choices=('one', 'two'), help='Bad choice: {error_msg}') - req = Mock(['values']) + req = Mock(["values", "get_json"]) req.values = MultiDict([('foo', 'three')]) + req.get_json.return_value = None parser.parse_args(req) expected = {'foo': 'Bad choice: three is not a valid choice'} abort.assert_called_with(400, message=expected) @@ -35,8 +36,9 @@ class ReqParseTestCase(unittest.TestCase): with app.app_context(): parser = RequestParser() parser.add_argument('foo', choices=('one', 'two'), help=u'Bad choice: {error_msg}') - req = Mock(['values']) + req = Mock(["values", "get_json"]) req.values = MultiDict([('foo', u'\xf0\x9f\x8d\x95')]) + req.get_json.return_value = None parser.parse_args(req) expected = {'foo': u'Bad choice: \xf0\x9f\x8d\x95 is not a valid choice'} abort.assert_called_with(400, message=expected) @@ -47,8 +49,9 @@ class ReqParseTestCase(unittest.TestCase): with app.app_context(): parser = RequestParser() parser.add_argument('foo', choices=['one', 'two'], help='Please select a valid choice') - req = Mock(['values']) + req = Mock(["values", "get_json"]) req.values = MultiDict([('foo', 'three')]) + req.get_json.return_value = None parser.parse_args(req) expected = {'foo': 'Please select a valid choice'} abort.assert_called_with(400, message=expected) @@ -58,8 +61,9 @@ class ReqParseTestCase(unittest.TestCase): def bad_choice(): parser = RequestParser() parser.add_argument('foo', choices=['one', 'two']) - req = Mock(['values']) + req = Mock(["values", "get_json"]) req.values = MultiDict([('foo', 'three')]) + req.get_json.return_value = None parser.parse_args(req) abort.assert_called_with(400, message='three is not a valid choice') app = Flask(__name__) @@ -190,7 +194,8 @@ class ReqParseTestCase(unittest.TestCase): self.assertTrue(len(arg.source(req)) == 0) # yes, basically you don't find it def test_source_default_location(self): - req = Mock(['values']) + req = Mock(['values', 'get_json']) + req.get_json.return_value = None req._get_child_mock = lambda **kwargs: MultiDict() arg = Argument('foo') self.assertEqual(arg.source(req), req.values) @@ -215,8 +220,9 @@ class ReqParseTestCase(unittest.TestCase): args = parser.parse_args(req) self.assertEqual(args['foo'], "bar") - req = Mock() + req = Mock(['get_json']) req.values = () + req.get_json.return_value = None req.json = None req.view_args = {"foo": "bar"} parser = RequestParser()