key-auth consumer support set independent key source (#1392)

This commit is contained in:
澄潭
2024-10-15 20:52:03 +08:00
committed by GitHub
parent 0a112d1a1e
commit 85f8eb5166
3 changed files with 219 additions and 62 deletions

View File

@@ -42,10 +42,7 @@ static RegisterContextFactory register_KeyAuth(CONTEXT_FACTORY(PluginContext),
namespace {
void deniedNoKeyAuthData(const std::string& realm) {
sendLocalResponse(401, "No API key found in request", "",
{{"WWW-Authenticate", absl::StrCat("Key realm=", realm)}});
}
const std::string OriginalAuthKey("X-HI-ORIGINAL-AUTH");
void deniedInvalidCredentials(const std::string& realm) {
sendLocalResponse(401, "Request denied by Key Auth check. Invalid API key",
@@ -84,6 +81,7 @@ bool PluginRootContext::parsePluginConfig(const json& configuration,
}
if (!JsonArrayIterate(
configuration, "consumers", [&](const json& consumer) -> bool {
Consumer c;
auto item = consumer.find("name");
if (item == consumer.end()) {
LOG_WARN("can't find 'name' field in consumer.");
@@ -94,6 +92,7 @@ bool PluginRootContext::parsePluginConfig(const json& configuration,
!name.first) {
return false;
}
c.name = name.first.value();
item = consumer.find("credential");
if (item == consumer.end()) {
LOG_WARN("can't find 'credential' field in consumer.");
@@ -104,6 +103,7 @@ bool PluginRootContext::parsePluginConfig(const json& configuration,
!credential.first) {
return false;
}
c.credential = credential.first.value();
if (rule.credential_to_name.find(credential.first.value()) !=
rule.credential_to_name.end()) {
LOG_WARN(absl::StrCat("duplicate consumer credential: ",
@@ -113,15 +113,59 @@ bool PluginRootContext::parsePluginConfig(const json& configuration,
rule.credentials.insert(credential.first.value());
rule.credential_to_name.emplace(
std::make_pair(credential.first.value(), name.first.value()));
item = consumer.find("keys");
if (item != consumer.end()) {
c.keys = std::vector<std::string>{OriginalAuthKey};
if (!JsonArrayIterate(
consumer, "keys", [&](const json& key_json) -> bool {
auto key = JsonValueAs<std::string>(key_json);
if (key.second !=
Wasm::Common::JsonParserResultDetail::OK) {
return false;
}
c.keys->push_back(key.first.value());
return true;
})) {
LOG_WARN("failed to parse configuration for consumer keys.");
return false;
}
item = consumer.find("in_query");
if (item != consumer.end()) {
auto in_query = JsonValueAs<bool>(item.value());
if (in_query.second !=
Wasm::Common::JsonParserResultDetail::OK ||
!in_query.first) {
LOG_WARN(
"failed to parse 'in_query' field in consumer "
"configuration.");
return false;
}
c.in_query = in_query.first;
}
item = consumer.find("in_header");
if (item != consumer.end()) {
auto in_header = JsonValueAs<bool>(item.value());
if (in_header.second !=
Wasm::Common::JsonParserResultDetail::OK ||
!in_header.first) {
LOG_WARN(
"failed to parse 'in_header' field in consumer "
"configuration.");
return false;
}
c.in_header = in_header.first;
}
}
rule.consumers.push_back(std::move(c));
return true;
})) {
LOG_WARN("failed to parse configuration for credentials.");
return false;
}
if (rule.credentials.empty()) {
LOG_INFO("at least one credential has to be configured for a rule.");
return false;
}
// if (rule.credentials.empty()) {
// LOG_INFO("at least one credential has to be configured for a rule.");
// return false;
// }
if (!JsonArrayIterate(configuration, "keys", [&](const json& item) -> bool {
auto key = JsonValueAs<std::string>(item);
if (key.second != Wasm::Common::JsonParserResultDetail::OK) {
@@ -137,6 +181,7 @@ bool PluginRootContext::parsePluginConfig(const json& configuration,
LOG_WARN("at least one key has to be configured for a rule.");
return false;
}
rule.keys.push_back(OriginalAuthKey);
auto it = configuration.find("realm");
if (it != configuration.end()) {
auto realm_string = JsonValueAs<std::string>(it.value());
@@ -175,36 +220,102 @@ bool PluginRootContext::parsePluginConfig(const json& configuration,
bool PluginRootContext::checkPlugin(
const KeyAuthConfigRule& rule,
const std::optional<std::unordered_set<std::string>>& allow_set) {
auto credential = extractCredential(rule);
if (credential.empty()) {
LOG_DEBUG("empty credential");
deniedNoKeyAuthData(rule.realm);
return false;
}
auto auth_credential_iter = rule.credentials.find(std::string(credential));
// Check if the credential is part of the credentials
// set from our container to grant or deny access.
if (auth_credential_iter == rule.credentials.end()) {
LOG_DEBUG(absl::StrCat("api key not found: ", credential));
deniedInvalidCredentials(rule.realm);
return false;
}
// Check if this credential has a consumer name. If so, check if this
// consumer is allowed to access. If allow_set is empty, allow all consumers.
auto credential_to_name_iter =
rule.credential_to_name.find(std::string(std::string(credential)));
if (credential_to_name_iter != rule.credential_to_name.end()) {
if (allow_set && !allow_set.value().empty()) {
if (allow_set.value().find(credential_to_name_iter->second) ==
allow_set.value().end()) {
deniedUnauthorizedConsumer(rule.realm);
LOG_DEBUG(credential_to_name_iter->second);
return false;
if (rule.consumers.empty()) {
for (const auto& key : rule.keys) {
auto credential = extractCredential(rule.in_header, rule.in_query, key);
if (credential.empty()) {
LOG_DEBUG("empty credential for key: " + key);
continue;
}
auto auth_credential_iter = rule.credentials.find(credential);
if (auth_credential_iter == rule.credentials.end()) {
LOG_DEBUG("api key not found: " + credential);
continue;
}
auto credential_to_name_iter = rule.credential_to_name.find(credential);
if (credential_to_name_iter != rule.credential_to_name.end()) {
if (allow_set && !allow_set->empty()) {
if (allow_set->find(credential_to_name_iter->second) ==
allow_set->end()) {
deniedUnauthorizedConsumer(rule.realm);
LOG_DEBUG("unauthorized consumer: " +
credential_to_name_iter->second);
return false;
}
}
addRequestHeader("X-Mse-Consumer", credential_to_name_iter->second);
}
return true;
}
} else {
for (const auto& consumer : rule.consumers) {
std::vector<std::string> keys_to_check =
consumer.keys.value_or(rule.keys);
bool in_query = consumer.in_query.value_or(rule.in_query);
bool in_header = consumer.in_header.value_or(rule.in_header);
for (const auto& key : keys_to_check) {
auto credential = extractCredential(in_header, in_query, key);
if (credential.empty()) {
LOG_DEBUG("empty credential for key: " + key);
continue;
}
if (credential != consumer.credential) {
LOG_DEBUG("credential does not match the consumer's credential: " +
credential);
continue;
}
auto auth_credential_iter = rule.credentials.find(credential);
if (auth_credential_iter == rule.credentials.end()) {
LOG_DEBUG("api key not found: " + credential);
continue;
}
auto credential_to_name_iter = rule.credential_to_name.find(credential);
if (credential_to_name_iter != rule.credential_to_name.end()) {
if (allow_set && !allow_set->empty()) {
if (allow_set->find(credential_to_name_iter->second) ==
allow_set->end()) {
deniedUnauthorizedConsumer(rule.realm);
LOG_DEBUG("unauthorized consumer: " +
credential_to_name_iter->second);
return false;
}
}
addRequestHeader("X-Mse-Consumer", credential_to_name_iter->second);
}
return true;
}
}
addRequestHeader("X-Mse-Consumer", credential_to_name_iter->second);
}
return true;
LOG_DEBUG("No valid credentials were found after checking all consumers.");
deniedInvalidCredentials(rule.realm);
return false;
}
std::string PluginRootContext::extractCredential(bool in_header, bool in_query,
const std::string& key) {
if (in_header) {
auto header = getRequestHeader(key);
if (header->size() != 0) {
return header->toString();
}
}
if (in_query) {
auto request_path_header = getRequestHeader(":path");
auto path = request_path_header->view();
auto params = Wasm::Common::Http::parseAndDecodeQueryString(path);
auto it = params.find(key);
if (it != params.end()) {
return it->second;
}
}
return "";
}
bool PluginRootContext::onConfigure(size_t size) {
@@ -234,31 +345,6 @@ bool PluginRootContext::configure(size_t configuration_size) {
return true;
}
std::string PluginRootContext::extractCredential(
const KeyAuthConfigRule& rule) {
auto request_path_header = getRequestHeader(":path");
auto path = request_path_header->view();
LOG_DEBUG(std::string(path));
if (rule.in_query) {
auto params = Wasm::Common::Http::parseAndDecodeQueryString(path);
for (const auto& key : rule.keys) {
auto it = params.find(key);
if (it != params.end()) {
return it->second;
}
}
}
if (rule.in_header) {
for (const auto& key : rule.keys) {
auto header = getRequestHeader(key);
if (header->size() != 0) {
return header->toString();
}
}
}
return "";
}
FilterHeadersStatus PluginContext::onRequestHeaders(uint32_t, bool) {
auto* rootCtx = rootContext();
return rootCtx->checkAuthRule(

View File

@@ -36,7 +36,16 @@ namespace key_auth {
#endif
struct Consumer {
std::string name;
std::string credential;
std::optional<std::vector<std::string>> keys;
std::optional<bool> in_query = std::nullopt;
std::optional<bool> in_header = std::nullopt;
};
struct KeyAuthConfigRule {
std::vector<Consumer> consumers;
std::unordered_set<std::string> credentials;
std::unordered_map<std::string, std::string> credential_to_name;
std::string realm = "MSE Gateway";
@@ -61,7 +70,8 @@ class PluginRootContext : public RootContext,
private:
bool parsePluginConfig(const json&, KeyAuthConfigRule&) override;
std::string extractCredential(const KeyAuthConfigRule&);
std::string extractCredential(bool in_header, bool in_query,
const std::string& key);
};
// Per-stream context.

View File

@@ -121,7 +121,7 @@ TEST_F(KeyAuthTest, InQuery) {
"_rules_": [
{
"_match_route_": ["test"],
"credentials":["abc"],
"credentials":["abc","def"],
"keys": ["apiKey", "x-api-key"]
}
]
@@ -144,6 +144,10 @@ TEST_F(KeyAuthTest, InQuery) {
path_ = "/test?hello=123&apiKey=123";
EXPECT_EQ(context_->onRequestHeaders(0, false),
FilterHeadersStatus::StopIteration);
path_ = "/test?hello=123&apiKey=123&x-api-key=def";
EXPECT_EQ(context_->onRequestHeaders(0, false),
FilterHeadersStatus::Continue);
}
TEST_F(KeyAuthTest, InQueryWithConsumer) {
@@ -173,6 +177,29 @@ TEST_F(KeyAuthTest, InQueryWithConsumer) {
FilterHeadersStatus::StopIteration);
}
TEST_F(KeyAuthTest, EmptyConsumer) {
std::string configuration = R"(
{
"consumers" : [],
"keys" : [ "apiKey", "x-api-key" ],
"_rules_" : [ {"_match_route_" : ["test"], "allow" : []} ]
})";
BufferBase buffer;
buffer.set(configuration);
EXPECT_CALL(*mock_context_, getBuffer(WasmBufferType::PluginConfiguration))
.WillOnce([&buffer](WasmBufferType) { return &buffer; });
EXPECT_TRUE(root_context_->configure(configuration.size()));
route_name_ = "test";
path_ = "/test?hello=1&apiKey=abc";
EXPECT_EQ(context_->onRequestHeaders(0, false),
FilterHeadersStatus::StopIteration);
route_name_ = "test2";
EXPECT_EQ(context_->onRequestHeaders(0, false),
FilterHeadersStatus::Continue);
}
TEST_F(KeyAuthTest, InHeader) {
std::string configuration = R"(
{
@@ -240,6 +267,40 @@ TEST_F(KeyAuthTest, InHeaderWithConsumer) {
FilterHeadersStatus::StopIteration);
}
TEST_F(KeyAuthTest, ConsumerDifferentKey) {
std::string configuration = R"(
{
"consumers" : [ {"credential" : "abc", "name" : "consumer1", "keys" : [ "apiKey" ]}, {"credential" : "123", "name" : "consumer2"} ],
"keys" : [ "apiKey2" ],
"_rules_" : [ {"_match_route_" : ["test"], "allow" : ["consumer1"]}, {"_match_route_" : ["test2"], "allow" : ["consumer2"]} ]
})";
BufferBase buffer;
buffer.set(configuration);
EXPECT_CALL(*mock_context_, getBuffer(WasmBufferType::PluginConfiguration))
.WillOnce([&buffer](WasmBufferType) { return &buffer; });
EXPECT_TRUE(root_context_->configure(configuration.size()));
route_name_ = "test";
path_ = "/test?hello=1&apiKey=abc";
EXPECT_EQ(context_->onRequestHeaders(0, false),
FilterHeadersStatus::Continue);
route_name_ = "test";
path_ = "/test?hello=1&apiKey2=abc";
EXPECT_EQ(context_->onRequestHeaders(0, false),
FilterHeadersStatus::StopIteration);
route_name_ = "test";
path_ = "/test?hello=123&apiKey2=123";
EXPECT_EQ(context_->onRequestHeaders(0, false),
FilterHeadersStatus::StopIteration);
route_name_ = "test2";
path_ = "/test?hello=123&apiKey2=123";
EXPECT_EQ(context_->onRequestHeaders(0, false),
FilterHeadersStatus::Continue);
}
} // namespace key_auth
} // namespace null_plugin
} // namespace proxy_wasm