diff --git a/plugins/wasm-cpp/extensions/key_auth/plugin.cc b/plugins/wasm-cpp/extensions/key_auth/plugin.cc index a88b6bda7..ad111760e 100644 --- a/plugins/wasm-cpp/extensions/key_auth/plugin.cc +++ b/plugins/wasm-cpp/extensions/key_auth/plugin.cc @@ -42,10 +42,7 @@ static RegisterContextFactory register_KeyAuth(CONTEXT_FACTORY(PluginContext), namespace { -void deniedNoKeyAuthData(const std::string& realm) { - sendLocalResponse(401, "No API key found in request", "", - {{"WWW-Authenticate", absl::StrCat("Key realm=", realm)}}); -} +const std::string OriginalAuthKey("X-HI-ORIGINAL-AUTH"); void deniedInvalidCredentials(const std::string& realm) { sendLocalResponse(401, "Request denied by Key Auth check. Invalid API key", @@ -84,6 +81,7 @@ bool PluginRootContext::parsePluginConfig(const json& configuration, } if (!JsonArrayIterate( configuration, "consumers", [&](const json& consumer) -> bool { + Consumer c; auto item = consumer.find("name"); if (item == consumer.end()) { LOG_WARN("can't find 'name' field in consumer."); @@ -94,6 +92,7 @@ bool PluginRootContext::parsePluginConfig(const json& configuration, !name.first) { return false; } + c.name = name.first.value(); item = consumer.find("credential"); if (item == consumer.end()) { LOG_WARN("can't find 'credential' field in consumer."); @@ -104,6 +103,7 @@ bool PluginRootContext::parsePluginConfig(const json& configuration, !credential.first) { return false; } + c.credential = credential.first.value(); if (rule.credential_to_name.find(credential.first.value()) != rule.credential_to_name.end()) { LOG_WARN(absl::StrCat("duplicate consumer credential: ", @@ -113,15 +113,59 @@ bool PluginRootContext::parsePluginConfig(const json& configuration, rule.credentials.insert(credential.first.value()); rule.credential_to_name.emplace( std::make_pair(credential.first.value(), name.first.value())); + item = consumer.find("keys"); + if (item != consumer.end()) { + c.keys = std::vector{OriginalAuthKey}; + if (!JsonArrayIterate( + consumer, "keys", [&](const json& key_json) -> bool { + auto key = JsonValueAs(key_json); + if (key.second != + Wasm::Common::JsonParserResultDetail::OK) { + return false; + } + c.keys->push_back(key.first.value()); + return true; + })) { + LOG_WARN("failed to parse configuration for consumer keys."); + return false; + } + item = consumer.find("in_query"); + if (item != consumer.end()) { + auto in_query = JsonValueAs(item.value()); + if (in_query.second != + Wasm::Common::JsonParserResultDetail::OK || + !in_query.first) { + LOG_WARN( + "failed to parse 'in_query' field in consumer " + "configuration."); + return false; + } + c.in_query = in_query.first; + } + item = consumer.find("in_header"); + if (item != consumer.end()) { + auto in_header = JsonValueAs(item.value()); + if (in_header.second != + Wasm::Common::JsonParserResultDetail::OK || + !in_header.first) { + LOG_WARN( + "failed to parse 'in_header' field in consumer " + "configuration."); + return false; + } + c.in_header = in_header.first; + } + } + rule.consumers.push_back(std::move(c)); return true; })) { LOG_WARN("failed to parse configuration for credentials."); return false; } - if (rule.credentials.empty()) { - LOG_INFO("at least one credential has to be configured for a rule."); - return false; - } + // if (rule.credentials.empty()) { + // LOG_INFO("at least one credential has to be configured for a rule."); + // return false; + // } if (!JsonArrayIterate(configuration, "keys", [&](const json& item) -> bool { auto key = JsonValueAs(item); if (key.second != Wasm::Common::JsonParserResultDetail::OK) { @@ -137,6 +181,7 @@ bool PluginRootContext::parsePluginConfig(const json& configuration, LOG_WARN("at least one key has to be configured for a rule."); return false; } + rule.keys.push_back(OriginalAuthKey); auto it = configuration.find("realm"); if (it != configuration.end()) { auto realm_string = JsonValueAs(it.value()); @@ -175,36 +220,102 @@ bool PluginRootContext::parsePluginConfig(const json& configuration, bool PluginRootContext::checkPlugin( const KeyAuthConfigRule& rule, const std::optional>& allow_set) { - auto credential = extractCredential(rule); - if (credential.empty()) { - LOG_DEBUG("empty credential"); - deniedNoKeyAuthData(rule.realm); - return false; - } - auto auth_credential_iter = rule.credentials.find(std::string(credential)); - // Check if the credential is part of the credentials - // set from our container to grant or deny access. - if (auth_credential_iter == rule.credentials.end()) { - LOG_DEBUG(absl::StrCat("api key not found: ", credential)); - deniedInvalidCredentials(rule.realm); - return false; - } - // Check if this credential has a consumer name. If so, check if this - // consumer is allowed to access. If allow_set is empty, allow all consumers. - auto credential_to_name_iter = - rule.credential_to_name.find(std::string(std::string(credential))); - if (credential_to_name_iter != rule.credential_to_name.end()) { - if (allow_set && !allow_set.value().empty()) { - if (allow_set.value().find(credential_to_name_iter->second) == - allow_set.value().end()) { - deniedUnauthorizedConsumer(rule.realm); - LOG_DEBUG(credential_to_name_iter->second); - return false; + if (rule.consumers.empty()) { + for (const auto& key : rule.keys) { + auto credential = extractCredential(rule.in_header, rule.in_query, key); + if (credential.empty()) { + LOG_DEBUG("empty credential for key: " + key); + continue; + } + + auto auth_credential_iter = rule.credentials.find(credential); + if (auth_credential_iter == rule.credentials.end()) { + LOG_DEBUG("api key not found: " + credential); + continue; + } + + auto credential_to_name_iter = rule.credential_to_name.find(credential); + if (credential_to_name_iter != rule.credential_to_name.end()) { + if (allow_set && !allow_set->empty()) { + if (allow_set->find(credential_to_name_iter->second) == + allow_set->end()) { + deniedUnauthorizedConsumer(rule.realm); + LOG_DEBUG("unauthorized consumer: " + + credential_to_name_iter->second); + return false; + } + } + addRequestHeader("X-Mse-Consumer", credential_to_name_iter->second); + } + return true; + } + } else { + for (const auto& consumer : rule.consumers) { + std::vector keys_to_check = + consumer.keys.value_or(rule.keys); + bool in_query = consumer.in_query.value_or(rule.in_query); + bool in_header = consumer.in_header.value_or(rule.in_header); + + for (const auto& key : keys_to_check) { + auto credential = extractCredential(in_header, in_query, key); + if (credential.empty()) { + LOG_DEBUG("empty credential for key: " + key); + continue; + } + + if (credential != consumer.credential) { + LOG_DEBUG("credential does not match the consumer's credential: " + + credential); + continue; + } + + auto auth_credential_iter = rule.credentials.find(credential); + if (auth_credential_iter == rule.credentials.end()) { + LOG_DEBUG("api key not found: " + credential); + continue; + } + + auto credential_to_name_iter = rule.credential_to_name.find(credential); + if (credential_to_name_iter != rule.credential_to_name.end()) { + if (allow_set && !allow_set->empty()) { + if (allow_set->find(credential_to_name_iter->second) == + allow_set->end()) { + deniedUnauthorizedConsumer(rule.realm); + LOG_DEBUG("unauthorized consumer: " + + credential_to_name_iter->second); + return false; + } + } + addRequestHeader("X-Mse-Consumer", credential_to_name_iter->second); + } + return true; } } - addRequestHeader("X-Mse-Consumer", credential_to_name_iter->second); } - return true; + + LOG_DEBUG("No valid credentials were found after checking all consumers."); + deniedInvalidCredentials(rule.realm); + return false; +} + +std::string PluginRootContext::extractCredential(bool in_header, bool in_query, + const std::string& key) { + if (in_header) { + auto header = getRequestHeader(key); + if (header->size() != 0) { + return header->toString(); + } + } + if (in_query) { + auto request_path_header = getRequestHeader(":path"); + auto path = request_path_header->view(); + auto params = Wasm::Common::Http::parseAndDecodeQueryString(path); + auto it = params.find(key); + if (it != params.end()) { + return it->second; + } + } + return ""; } bool PluginRootContext::onConfigure(size_t size) { @@ -234,31 +345,6 @@ bool PluginRootContext::configure(size_t configuration_size) { return true; } -std::string PluginRootContext::extractCredential( - const KeyAuthConfigRule& rule) { - auto request_path_header = getRequestHeader(":path"); - auto path = request_path_header->view(); - LOG_DEBUG(std::string(path)); - if (rule.in_query) { - auto params = Wasm::Common::Http::parseAndDecodeQueryString(path); - for (const auto& key : rule.keys) { - auto it = params.find(key); - if (it != params.end()) { - return it->second; - } - } - } - if (rule.in_header) { - for (const auto& key : rule.keys) { - auto header = getRequestHeader(key); - if (header->size() != 0) { - return header->toString(); - } - } - } - return ""; -} - FilterHeadersStatus PluginContext::onRequestHeaders(uint32_t, bool) { auto* rootCtx = rootContext(); return rootCtx->checkAuthRule( diff --git a/plugins/wasm-cpp/extensions/key_auth/plugin.h b/plugins/wasm-cpp/extensions/key_auth/plugin.h index 4a2b7418d..4cf329d90 100644 --- a/plugins/wasm-cpp/extensions/key_auth/plugin.h +++ b/plugins/wasm-cpp/extensions/key_auth/plugin.h @@ -36,7 +36,16 @@ namespace key_auth { #endif +struct Consumer { + std::string name; + std::string credential; + std::optional> keys; + std::optional in_query = std::nullopt; + std::optional in_header = std::nullopt; +}; + struct KeyAuthConfigRule { + std::vector consumers; std::unordered_set credentials; std::unordered_map credential_to_name; std::string realm = "MSE Gateway"; @@ -61,7 +70,8 @@ class PluginRootContext : public RootContext, private: bool parsePluginConfig(const json&, KeyAuthConfigRule&) override; - std::string extractCredential(const KeyAuthConfigRule&); + std::string extractCredential(bool in_header, bool in_query, + const std::string& key); }; // Per-stream context. diff --git a/plugins/wasm-cpp/extensions/key_auth/plugin_test.cc b/plugins/wasm-cpp/extensions/key_auth/plugin_test.cc index fc70df1b0..2f60811e0 100644 --- a/plugins/wasm-cpp/extensions/key_auth/plugin_test.cc +++ b/plugins/wasm-cpp/extensions/key_auth/plugin_test.cc @@ -121,7 +121,7 @@ TEST_F(KeyAuthTest, InQuery) { "_rules_": [ { "_match_route_": ["test"], - "credentials":["abc"], + "credentials":["abc","def"], "keys": ["apiKey", "x-api-key"] } ] @@ -144,6 +144,10 @@ TEST_F(KeyAuthTest, InQuery) { path_ = "/test?hello=123&apiKey=123"; EXPECT_EQ(context_->onRequestHeaders(0, false), FilterHeadersStatus::StopIteration); + + path_ = "/test?hello=123&apiKey=123&x-api-key=def"; + EXPECT_EQ(context_->onRequestHeaders(0, false), + FilterHeadersStatus::Continue); } TEST_F(KeyAuthTest, InQueryWithConsumer) { @@ -173,6 +177,29 @@ TEST_F(KeyAuthTest, InQueryWithConsumer) { FilterHeadersStatus::StopIteration); } +TEST_F(KeyAuthTest, EmptyConsumer) { + std::string configuration = R"( +{ + "consumers" : [], + "keys" : [ "apiKey", "x-api-key" ], + "_rules_" : [ {"_match_route_" : ["test"], "allow" : []} ] +})"; + BufferBase buffer; + buffer.set(configuration); + EXPECT_CALL(*mock_context_, getBuffer(WasmBufferType::PluginConfiguration)) + .WillOnce([&buffer](WasmBufferType) { return &buffer; }); + EXPECT_TRUE(root_context_->configure(configuration.size())); + + route_name_ = "test"; + path_ = "/test?hello=1&apiKey=abc"; + EXPECT_EQ(context_->onRequestHeaders(0, false), + FilterHeadersStatus::StopIteration); + + route_name_ = "test2"; + EXPECT_EQ(context_->onRequestHeaders(0, false), + FilterHeadersStatus::Continue); +} + TEST_F(KeyAuthTest, InHeader) { std::string configuration = R"( { @@ -240,6 +267,40 @@ TEST_F(KeyAuthTest, InHeaderWithConsumer) { FilterHeadersStatus::StopIteration); } +TEST_F(KeyAuthTest, ConsumerDifferentKey) { + std::string configuration = R"( +{ + "consumers" : [ {"credential" : "abc", "name" : "consumer1", "keys" : [ "apiKey" ]}, {"credential" : "123", "name" : "consumer2"} ], + "keys" : [ "apiKey2" ], + "_rules_" : [ {"_match_route_" : ["test"], "allow" : ["consumer1"]}, {"_match_route_" : ["test2"], "allow" : ["consumer2"]} ] +})"; + BufferBase buffer; + buffer.set(configuration); + EXPECT_CALL(*mock_context_, getBuffer(WasmBufferType::PluginConfiguration)) + .WillOnce([&buffer](WasmBufferType) { return &buffer; }); + EXPECT_TRUE(root_context_->configure(configuration.size())); + + route_name_ = "test"; + path_ = "/test?hello=1&apiKey=abc"; + EXPECT_EQ(context_->onRequestHeaders(0, false), + FilterHeadersStatus::Continue); + + route_name_ = "test"; + path_ = "/test?hello=1&apiKey2=abc"; + EXPECT_EQ(context_->onRequestHeaders(0, false), + FilterHeadersStatus::StopIteration); + + route_name_ = "test"; + path_ = "/test?hello=123&apiKey2=123"; + EXPECT_EQ(context_->onRequestHeaders(0, false), + FilterHeadersStatus::StopIteration); + + route_name_ = "test2"; + path_ = "/test?hello=123&apiKey2=123"; + EXPECT_EQ(context_->onRequestHeaders(0, false), + FilterHeadersStatus::Continue); +} + } // namespace key_auth } // namespace null_plugin } // namespace proxy_wasm