diff --git a/test/auth/pam_authenticator.cc b/test/auth/pam_authenticator.cc index 36689cd..1946111 100644 --- a/test/auth/pam_authenticator.cc +++ b/test/auth/pam_authenticator.cc @@ -51,10 +51,10 @@ TEST(pam_authenticator, get_type) TEST(pam_authenticator, authenticate) { MockPam pam; - EXPECT_CALL(pam, start(StrEq("webfused"), nullptr, _, _)).Times(1).WillOnce(Return(PAM_SUCCESS)); - EXPECT_CALL(pam, authenticate(_, PAM_DISALLOW_NULL_AUTHTOK)).Times(1).WillOnce(Return(PAM_SUCCESS)); - EXPECT_CALL(pam, acct_mgmt(_, PAM_DISALLOW_NULL_AUTHTOK)).Times(1).WillOnce(Return(PAM_SUCCESS)); - EXPECT_CALL(pam, end(_, _)).Times(1).WillOnce(Return(PAM_SUCCESS)); + EXPECT_CALL(pam, pam_start(StrEq("webfused"), nullptr, _, _)).Times(1).WillOnce(Return(PAM_SUCCESS)); + EXPECT_CALL(pam, pam_authenticate(_, PAM_DISALLOW_NULL_AUTHTOK)).Times(1).WillOnce(Return(PAM_SUCCESS)); + EXPECT_CALL(pam, pam_acct_mgmt(_, PAM_DISALLOW_NULL_AUTHTOK)).Times(1).WillOnce(Return(PAM_SUCCESS)); + EXPECT_CALL(pam, pam_end(_, _)).Times(1).WillOnce(Return(PAM_SUCCESS)); wfd_authenticator authenticator; bool success = wfd_pam_authenticator_create(nullptr, &authenticator); @@ -73,10 +73,10 @@ TEST(pam_authenticator, authenticate) TEST(pam_authenticator, authenticate_with_custom_service_name) { MockPam pam; - EXPECT_CALL(pam, start(StrEq("brummni"), nullptr, _, _)).Times(1).WillOnce(Return(PAM_SUCCESS)); - EXPECT_CALL(pam, authenticate(_, PAM_DISALLOW_NULL_AUTHTOK)).Times(1).WillOnce(Return(PAM_SUCCESS)); - EXPECT_CALL(pam, acct_mgmt(_, PAM_DISALLOW_NULL_AUTHTOK)).Times(1).WillOnce(Return(PAM_SUCCESS)); - EXPECT_CALL(pam, end(_, _)).Times(1).WillOnce(Return(PAM_SUCCESS)); + EXPECT_CALL(pam, pam_start(StrEq("brummni"), nullptr, _, _)).Times(1).WillOnce(Return(PAM_SUCCESS)); + EXPECT_CALL(pam, pam_authenticate(_, PAM_DISALLOW_NULL_AUTHTOK)).Times(1).WillOnce(Return(PAM_SUCCESS)); + EXPECT_CALL(pam, pam_acct_mgmt(_, PAM_DISALLOW_NULL_AUTHTOK)).Times(1).WillOnce(Return(PAM_SUCCESS)); + EXPECT_CALL(pam, pam_end(_, _)).Times(1).WillOnce(Return(PAM_SUCCESS)); MockSettings settings; EXPECT_CALL(settings, getStringOrDefault(StrEq("service_name"), StrEq("webfused"))) @@ -133,11 +133,11 @@ int valid_conversation( TEST(pam_authenticator, conversation_with_valid_messages) { MockPam pam; - EXPECT_CALL(pam, start(StrEq("webfused"), nullptr, _, _)) + EXPECT_CALL(pam, pam_start(StrEq("webfused"), nullptr, _, _)) .Times(1).WillOnce(Invoke(&valid_conversation)); - EXPECT_CALL(pam, authenticate(_, PAM_DISALLOW_NULL_AUTHTOK)).Times(1).WillOnce(Return(PAM_SUCCESS)); - EXPECT_CALL(pam, acct_mgmt(_, PAM_DISALLOW_NULL_AUTHTOK)).Times(1).WillOnce(Return(PAM_SUCCESS)); - EXPECT_CALL(pam, end(_, _)).Times(1).WillOnce(Return(PAM_SUCCESS)); + EXPECT_CALL(pam, pam_authenticate(_, PAM_DISALLOW_NULL_AUTHTOK)).Times(1).WillOnce(Return(PAM_SUCCESS)); + EXPECT_CALL(pam, pam_acct_mgmt(_, PAM_DISALLOW_NULL_AUTHTOK)).Times(1).WillOnce(Return(PAM_SUCCESS)); + EXPECT_CALL(pam, pam_end(_, _)).Times(1).WillOnce(Return(PAM_SUCCESS)); wfd_authenticator authenticator; bool success = wfd_pam_authenticator_create(nullptr, &authenticator); @@ -182,7 +182,7 @@ int invalid_conversation( TEST(pam_authenticator, conversation_with_invalid_messages) { MockPam pam; - EXPECT_CALL(pam, start(StrEq("webfused"), nullptr, _, _)) + EXPECT_CALL(pam, pam_start(StrEq("webfused"), nullptr, _, _)) .Times(1).WillOnce(Invoke(&invalid_conversation)); wfd_authenticator authenticator; @@ -202,9 +202,9 @@ TEST(pam_authenticator, conversation_with_invalid_messages) TEST(pam_authenticator, authenticate_fail_authenticate) { MockPam pam; - EXPECT_CALL(pam, start(StrEq("webfused"), nullptr, _, _)).Times(1).WillOnce(Return(PAM_SUCCESS)); - EXPECT_CALL(pam, authenticate(_, PAM_DISALLOW_NULL_AUTHTOK)).Times(1).WillOnce(Return(-1)); - EXPECT_CALL(pam, end(_, _)).Times(1).WillOnce(Return(PAM_SUCCESS)); + EXPECT_CALL(pam, pam_start(StrEq("webfused"), nullptr, _, _)).Times(1).WillOnce(Return(PAM_SUCCESS)); + EXPECT_CALL(pam, pam_authenticate(_, PAM_DISALLOW_NULL_AUTHTOK)).Times(1).WillOnce(Return(-1)); + EXPECT_CALL(pam, pam_end(_, _)).Times(1).WillOnce(Return(PAM_SUCCESS)); wfd_authenticator authenticator; bool success = wfd_pam_authenticator_create(nullptr, &authenticator); @@ -223,10 +223,10 @@ TEST(pam_authenticator, authenticate_fail_authenticate) TEST(pam_authenticator, authenticate_fail_acct_mgmt) { MockPam pam; - EXPECT_CALL(pam, start(StrEq("webfused"), nullptr, _, _)).Times(1).WillOnce(Return(PAM_SUCCESS)); - EXPECT_CALL(pam, authenticate(_, PAM_DISALLOW_NULL_AUTHTOK)).Times(1).WillOnce(Return(PAM_SUCCESS)); - EXPECT_CALL(pam, acct_mgmt(_, PAM_DISALLOW_NULL_AUTHTOK)).Times(1).WillOnce(Return(-1)); - EXPECT_CALL(pam, end(_, _)).Times(1).WillOnce(Return(PAM_SUCCESS)); + EXPECT_CALL(pam, pam_start(StrEq("webfused"), nullptr, _, _)).Times(1).WillOnce(Return(PAM_SUCCESS)); + EXPECT_CALL(pam, pam_authenticate(_, PAM_DISALLOW_NULL_AUTHTOK)).Times(1).WillOnce(Return(PAM_SUCCESS)); + EXPECT_CALL(pam, pam_acct_mgmt(_, PAM_DISALLOW_NULL_AUTHTOK)).Times(1).WillOnce(Return(-1)); + EXPECT_CALL(pam, pam_end(_, _)).Times(1).WillOnce(Return(PAM_SUCCESS)); wfd_authenticator authenticator; bool success = wfd_pam_authenticator_create(nullptr, &authenticator); diff --git a/test/mock/pam.cc b/test/mock/pam.cc index 1155e68..44ff44f 100644 --- a/test/mock/pam.cc +++ b/test/mock/pam.cc @@ -1,85 +1,15 @@ #include "mock/pam.hpp" +#include "util/wrap.hpp" extern "C" { - static webfused_test::IPam * wfd_MockPam = nullptr; -extern int __real_pam_start( - char const * service_name, - char const * user, - struct pam_conv const * conversation, - pam_handle_t * * handle); -extern int __real_pam_end(pam_handle_t * handle, int status); -extern int __real_pam_authenticate(pam_handle_t * handle, int flags); -extern int __real_pam_acct_mgmt(pam_handle_t * handle, int flags); -extern char const * __real_pam_strerror(pam_handle_t * handle, int errnum); - -int __wrap_pam_start( - char const * service_name, - char const * user, - struct pam_conv const * conversation, - pam_handle_t * * handle) -{ - if (nullptr == wfd_MockPam) - { - return __real_pam_start(service_name, user, conversation, handle); - } - else - { - return wfd_MockPam->start(service_name, user, conversation, handle); - } - -} - -int __wrap_pam_end(pam_handle_t * handle, int status) -{ - if (nullptr == wfd_MockPam) - { - return __real_pam_end(handle, status); - } - else - { - return wfd_MockPam->end(handle, status); - } -} - -int __wrap_pam_authenticate(pam_handle_t * handle, int flags) -{ - if (nullptr == wfd_MockPam) - { - return __real_pam_authenticate(handle, flags); - } - else - { - return wfd_MockPam->authenticate(handle, flags); - } -} - -int __wrap_pam_acct_mgmt(pam_handle_t * handle, int flags) -{ - if (nullptr == wfd_MockPam) - { - return __real_pam_acct_mgmt(handle, flags); - } - else - { - return wfd_MockPam->acct_mgmt(handle, flags); - } -} - -char const * __wrap_pam_strerror(pam_handle_t * handle, int errnum) -{ - if (nullptr == wfd_MockPam) - { - return __real_pam_strerror(handle, errnum); - } - else - { - return wfd_MockPam->strerror(handle, errnum); - } -} - +WFD_WRAP_FUNC4(wfd_MockPam, int, pam_start, char const *, char const *, struct pam_conv const *, pam_handle_t **); +WFD_WRAP_FUNC2(wfd_MockPam, int, pam_end, pam_handle_t *, int); +WFD_WRAP_FUNC2(wfd_MockPam, int, pam_authenticate, pam_handle_t *, int); +WFD_WRAP_FUNC2(wfd_MockPam, int, pam_acct_mgmt, pam_handle_t *, int); +WFD_WRAP_FUNC2(wfd_MockPam, char const *, pam_strerror, pam_handle_t *, int); } diff --git a/test/mock/pam.hpp b/test/mock/pam.hpp index a191dc8..cb6a19f 100644 --- a/test/mock/pam.hpp +++ b/test/mock/pam.hpp @@ -11,15 +11,15 @@ class IPam { public: virtual ~IPam() = default; - virtual int start( + virtual int pam_start( char const * service_name, char const * user, struct pam_conv const * conversation, pam_handle_t * * handle) = 0; - virtual int end(pam_handle_t * handle, int status) = 0; - virtual int authenticate(pam_handle_t * handle, int flags) = 0; - virtual int acct_mgmt(pam_handle_t * handle, int flags) = 0; - virtual char const * strerror(pam_handle_t * handle, int errnum) = 0; + virtual int pam_end(pam_handle_t * handle, int status) = 0; + virtual int pam_authenticate(pam_handle_t * handle, int flags) = 0; + virtual int pam_acct_mgmt(pam_handle_t * handle, int flags) = 0; + virtual char const * pam_strerror(pam_handle_t * handle, int errnum) = 0; }; class MockPam: public IPam @@ -28,16 +28,16 @@ public: MockPam(); ~MockPam() override; - MOCK_METHOD4(start, int ( + MOCK_METHOD4(pam_start, int ( char const * service_name, char const * user, struct pam_conv const * conversation, pam_handle_t * * handle)); - MOCK_METHOD2(end, int(pam_handle_t * handle, int status)); - MOCK_METHOD2(authenticate, int(pam_handle_t * handle, int flags)); - MOCK_METHOD2(acct_mgmt, int (pam_handle_t * handle, int flags)); - MOCK_METHOD2(strerror, char const * (pam_handle_t * handle, int errnum)); + MOCK_METHOD2(pam_end, int(pam_handle_t * handle, int status)); + MOCK_METHOD2(pam_authenticate, int(pam_handle_t * handle, int flags)); + MOCK_METHOD2(pam_acct_mgmt, int (pam_handle_t * handle, int flags)); + MOCK_METHOD2(pam_strerror, char const * (pam_handle_t * handle, int errnum)); }; } diff --git a/test/util/wrap.hpp b/test/util/wrap.hpp index efe7b0c..dc572ba 100644 --- a/test/util/wrap.hpp +++ b/test/util/wrap.hpp @@ -43,4 +43,32 @@ } \ } +#define WFD_WRAP_FUNC3( GLOBAL_VAR, RETURN_TYPE, FUNC_NAME, ARG1_TYPE, ARG2_TYPE, ARG3_TYPE ) \ + extern RETURN_TYPE __real_ ## FUNC_NAME (ARG1_TYPE, ARG2_TYPE, ARG3_TYPE); \ + RETURN_TYPE __wrap_ ## FUNC_NAME (ARG1_TYPE arg1, ARG2_TYPE arg2, ARG3_TYPE arg3) \ + { \ + if (nullptr == GLOBAL_VAR ) \ + { \ + return __real_ ## FUNC_NAME (arg1, arg2, arg3); \ + } \ + else \ + { \ + return GLOBAL_VAR -> FUNC_NAME(arg1, arg2, arg3); \ + } \ + } + +#define WFD_WRAP_FUNC4( GLOBAL_VAR, RETURN_TYPE, FUNC_NAME, ARG1_TYPE, ARG2_TYPE, ARG3_TYPE, ARG4_TYPE ) \ + extern RETURN_TYPE __real_ ## FUNC_NAME (ARG1_TYPE, ARG2_TYPE, ARG3_TYPE, ARG4_TYPE); \ + RETURN_TYPE __wrap_ ## FUNC_NAME (ARG1_TYPE arg1, ARG2_TYPE arg2, ARG3_TYPE arg3, ARG4_TYPE arg4) \ + { \ + if (nullptr == GLOBAL_VAR ) \ + { \ + return __real_ ## FUNC_NAME (arg1, arg2, arg3, arg4); \ + } \ + else \ + { \ + return GLOBAL_VAR -> FUNC_NAME(arg1, arg2, arg3, arg4); \ + } \ + } + #endif