diff --git a/src/http/request.cpp b/src/http/request.cpp index b20b435..25c30d6 100644 --- a/src/http/request.cpp +++ b/src/http/request.cpp @@ -17,31 +17,30 @@ namespace Http return static_cast(-1); } - Http::Request Request::Deserialize(std::vector const & bytes) + void Request::Deserialize(std::vector const & bytes, Logger & logger) { // TODO serialize more than just the start - Http::Request request; - std::stringstream ss(std::string(bytes.begin(), bytes.end())); std::string requestTypeString; ss >> requestTypeString; - request.type = ToEnum(requestTypeString, HttpRequest::typeStrings); - if (request.type == HttpRequest::Type::UNKNOWN) + type = ToEnum(requestTypeString, HttpRequest::typeStrings); + if (type == HttpRequest::Type::UNKNOWN) { - throw std::runtime_error("Bad request type"); + logger.Error("Request::Deserialize: Bad request type"); + return; } std::string rawUrl; ss >> rawUrl; - if(!request.url.TryParseFromUrlString(rawUrl)) + if(!url.TryParseFromUrlString(rawUrl)) { - throw std::runtime_error("Bad url in request"); + logger.Error("Request::Deserialize: Bad url in request"); + type = HttpRequest::Type::UNKNOWN; + return; } std::string httpProtocolString; ss >> httpProtocolString; - - return request; } } diff --git a/src/http/request.hpp b/src/http/request.hpp index 47b6db4..4ae199e 100644 --- a/src/http/request.hpp +++ b/src/http/request.hpp @@ -1,5 +1,6 @@ #pragma once #include "../constants/httprequest.hpp" +#include "../logger.hpp" #include "url.hpp" namespace Http @@ -9,6 +10,6 @@ namespace Http HttpRequest::Type type; Url url; - static Request Deserialize(std::vector const & bytes); + void Deserialize(std::vector const & bytes, Logger & logger); }; } diff --git a/src/http/response.cpp b/src/http/response.cpp index 9a3be78..dd00384 100644 --- a/src/http/response.cpp +++ b/src/http/response.cpp @@ -4,30 +4,30 @@ namespace Http { std::vector Response::Serialize() const + { + // TODO implement headers properly + std::stringstream ss; + ss << "HTTP/1.1"; + ss << ' ' << HttpResponse::codeValues[static_cast(code)]; + ss << ' ' << HttpResponse::codeStrings[static_cast(code)]; + ss << "\r\n"; + ss << "Server: http-server/0.1\r\n"; + + if (contentType.size() > 0) { - // TODO implement headers properly - std::stringstream ss; - ss << "HTTP/1.1"; - ss << ' ' << HttpResponse::codeValues[static_cast(code)]; - ss << ' ' << HttpResponse::codeStrings[static_cast(code)]; - ss << "\r\n"; - ss << "Server: http-server/0.1\r\n"; - - if (contentType.size() > 0) - { - ss << "Content-Type: "; - ss << contentType << "\r\n"; - } - - ss << "\r\n"; - - auto header = ss.str(); - std::vector buffer (header.begin(), header.end()); - buffer.insert(buffer.end(), content.begin(), content.end()); - - return buffer; + ss << "Content-Type: "; + ss << contentType << "\r\n"; } + ss << "\r\n"; + + auto header = ss.str(); + std::vector buffer (header.begin(), header.end()); + buffer.insert(buffer.end(), content.begin(), content.end()); + + return buffer; + } + Response::Response() : code(HttpResponse::Code::UNKNOWN), contentType(), diff --git a/src/logger.cpp b/src/logger.cpp index e840667..7a08c8e 100644 --- a/src/logger.cpp +++ b/src/logger.cpp @@ -11,11 +11,21 @@ void Logger::Success(const std::string & s) Log("SUCCESS", s); } +void Logger::Fatal(const std::string & s) +{ + Log("FATAL", s); +} + void Logger::Error(const std::string & s) { Log("ERROR", s); } +void Logger::Warning(const std::string & s) +{ + Log("WARNING", s); +} + void Logger::Info(const std::string & s) { Log("INFO", s); diff --git a/src/logger.hpp b/src/logger.hpp index 7f28dc9..39c8441 100644 --- a/src/logger.hpp +++ b/src/logger.hpp @@ -5,7 +5,9 @@ class Logger { public: void Success(const std::string & s); + void Fatal(const std::string & s); void Error(const std::string & s); + void Warning(const std::string & s); void Info(const std::string & s); void Debug(const std::string & s); diff --git a/src/main.cpp b/src/main.cpp index a2cdad2..3c8a58c 100755 --- a/src/main.cpp +++ b/src/main.cpp @@ -5,9 +5,10 @@ int main(int argc, char ** argv) { - ServerConfiguration serverConfiguration; Logger logger; - HttpServer httpServer(logger, serverConfiguration); + + Server::Configuration serverConfiguration; + Server::HttpServer httpServer(logger, serverConfiguration); httpServer.Execute(); diff --git a/src/middleware/base.cpp b/src/middleware/base.cpp index a7ae324..a7a2686 100644 --- a/src/middleware/base.cpp +++ b/src/middleware/base.cpp @@ -2,6 +2,17 @@ namespace Middleware { + State::State(int const socketFileDescriptor, Http::Request const & _request, Http::Response & _response) + : finished(false), + awaiting(false), + socketFd(socketFileDescriptor), + awaitingFd(-1), + request(_request), + response(_response), + middlewareIndex(0) + { + } + BaseMiddleware::BaseMiddleware(Logger & _logger) : logger(_logger) { diff --git a/src/middleware/base.hpp b/src/middleware/base.hpp index 6c1bfe8..89da1df 100644 --- a/src/middleware/base.hpp +++ b/src/middleware/base.hpp @@ -2,18 +2,35 @@ #include "../http/request.hpp" #include "../http/response.hpp" #include "../logger.hpp" +#include "../scheduler.hpp" #include #include namespace Middleware { + struct State + { + bool finished; + bool awaiting; + + int const socketFd; + int awaitingFd; + + Http::Request const & request; + Http::Response & response; + + size_t middlewareIndex; + + State(int const socketFileDescriptor, Http::Request const & request, Http::Response & response); + }; + class BaseMiddleware { protected: Logger & logger; public: - virtual void HandleRequest(Http::Request const & request, Http::Response & response) = 0; + virtual void HandleRequest(State & state, ScheduleHelper & scheduleHelper) = 0; BaseMiddleware(Logger & logger); virtual ~BaseMiddleware() = default; diff --git a/src/middleware/middleware.cpp b/src/middleware/middleware.cpp new file mode 100644 index 0000000..0e7a3ea --- /dev/null +++ b/src/middleware/middleware.cpp @@ -0,0 +1,96 @@ +#include "middleware.hpp" +#include "notfound.hpp" +#include "staticcontent.hpp" +#include "unimplemented.hpp" +#include + +namespace Middleware +{ + MiddlewareResult::MiddlewareResult(bool const _writeResult, int const _socketFd) + : writeResult(_writeResult), + socketFd(_socketFd) + { + } + + Middleware::Middleware(Logger & _logger) + : logger(_logger) + { + } + + MiddlewareResult Middleware::Loop(State & state, ScheduleHelper & scheduleHelper) + { + for(size_t i = state.middlewareIndex; i < middlewares.size(); ++i) + { + middlewares[i]->HandleRequest(state, scheduleHelper); + if (state.finished) + { + break; + } + + if (state.awaiting) + { + state.middlewareIndex = i; + eventFdToSocketFdMap[state.awaitingFd] = state.socketFd; + return MiddlewareResult (false, state.socketFd); + } + } + + int const socketFd = state.socketFd; + auto stateResult = socketFdToStateMap.find(socketFd); + if (stateResult != socketFdToStateMap.end()) + { + socketFdToStateMap.erase(stateResult); + // The state parameter is now invalid!! + } + + return MiddlewareResult(true, socketFd); + } + + MiddlewareResult Middleware::Start(int const socketFd, ScheduleHelper & scheduleHelper, Http::Request const & request, Http::Response & response) + { + auto & statePtr = socketFdToStateMap[socketFd]; + statePtr = std::make_unique(socketFd, request, response); + + return Loop(*statePtr, scheduleHelper); + } + + MiddlewareResult Middleware::Continue(int const fd, ScheduleHelper & scheduleHelper) + { + auto socketFdResult = eventFdToSocketFdMap.find(fd); + if (socketFdResult == eventFdToSocketFdMap.end()) + { + close(fd); + logger.Warning("Dangeling filedescriptor closed in middleware"); + return MiddlewareResult(false, -1); + } + + int const socketFd = socketFdResult->second; + eventFdToSocketFdMap.erase(socketFdResult); + + auto stateResult = socketFdToStateMap.find(socketFd); + if (stateResult == socketFdToStateMap.end()) + { + close(fd); + close(socketFd); + logger.Warning("Stateless socket file descriptor and associated event file descriptor closed"); + return MiddlewareResult(false, -1); + } + + return Loop(*(stateResult->second), scheduleHelper); + } + + void Middleware::UseMiddleware(std::unique_ptr && middleware) + { + middlewares.emplace_back(std::move(middleware)); + } + + Middleware Middleware::CreateWithStaticFile(Logger & logger, std::string const & rootDirectory) + { + Middleware middleware(logger); + middleware.UseMiddleware(std::make_unique(logger, rootDirectory)); + middleware.UseMiddleware(std::make_unique(logger)); + middleware.UseMiddleware(std::make_unique(logger)); + + return middleware; + } +} \ No newline at end of file diff --git a/src/middleware/middleware.hpp b/src/middleware/middleware.hpp new file mode 100644 index 0000000..1d44208 --- /dev/null +++ b/src/middleware/middleware.hpp @@ -0,0 +1,41 @@ +#pragma once +#include "../http/request.hpp" +#include "../http/response.hpp" +#include "../logger.hpp" +#include "../scheduler.hpp" +#include "base.hpp" +#include +#include +#include + +namespace Middleware +{ + struct MiddlewareResult + { + bool const writeResult; + int const socketFd; + + MiddlewareResult(bool const writeResult, int const socketFd); + }; + + class Middleware + { + private: + Logger & logger; + std::vector> middlewares; + std::map> socketFdToStateMap; + std::map eventFdToSocketFdMap; + + Middleware(Logger & logger); + + MiddlewareResult Loop(State & state, ScheduleHelper & scheduleHelper); + + public: + MiddlewareResult Start(int const socketFd, ScheduleHelper & scheduleHelper, Http::Request const & request, Http::Response & response); + MiddlewareResult Continue(int const fd, ScheduleHelper & scheduleHelper); + + void UseMiddleware(std::unique_ptr && middleware); + + static Middleware CreateWithStaticFile(Logger & logger, std::string const & rootDirectory); + }; +} \ No newline at end of file diff --git a/src/middleware/notfound.cpp b/src/middleware/notfound.cpp index f8e8abd..e87f229 100644 --- a/src/middleware/notfound.cpp +++ b/src/middleware/notfound.cpp @@ -4,30 +4,32 @@ namespace Middleware { - void NotFound::HandleRequest(Http::Request const & request, Http::Response & response) + void NotFound::HandleRequest(State & state, ScheduleHelper & scheduleHelper) { - if (response.code != HttpResponse::Code::UNKNOWN || - !(request.type == HttpRequest::Type::GET || request.type == HttpRequest::Type::HEAD)) + if (state.response.code != HttpResponse::Code::UNKNOWN || + !(state.request.type == HttpRequest::Type::GET || state.request.type == HttpRequest::Type::HEAD)) { return; } - response.code = HttpResponse::Code::NOT_FOUND; - if (request.type == HttpRequest::Type::GET) + state.response.code = HttpResponse::Code::NOT_FOUND; + if (state.request.type == HttpRequest::Type::GET) { - response.contentType = Http::GetMimeType(Http::FileType::HTML); + state.response.contentType = Http::GetMimeType(Http::FileType::HTML); std::stringstream ss; ss << ""; ss << "

404 - File Not Found

"; - ss << "

File: " << request.url.GetPath() << "

"; + ss << "

File: " << state.request.url.GetPath() << "

"; ss << ""; auto responseContent = ss.str(); - response.content.insert(response.content.begin(), + state.response.content.insert(state.response.content.begin(), responseContent.begin(), responseContent.end()); } + + state.finished = true; } NotFound::NotFound(Logger & _logger) diff --git a/src/middleware/notfound.hpp b/src/middleware/notfound.hpp index e5da9c8..63045d3 100644 --- a/src/middleware/notfound.hpp +++ b/src/middleware/notfound.hpp @@ -6,7 +6,7 @@ namespace Middleware class NotFound : public BaseMiddleware { public: - void HandleRequest(Http::Request const & request, Http::Response & Response) override; + void HandleRequest(State & state, ScheduleHelper & scheduleHelper) override; NotFound(Logger & logger); }; diff --git a/src/middleware/staticcontent.cpp b/src/middleware/staticcontent.cpp index 776793f..082106c 100644 --- a/src/middleware/staticcontent.cpp +++ b/src/middleware/staticcontent.cpp @@ -4,6 +4,7 @@ #include #include "staticcontent.hpp" #include +#include namespace Middleware { @@ -48,17 +49,17 @@ namespace Middleware return false; } - void StaticContent::HandleRequest(Http::Request const & request, Http::Response & response) + void StaticContent::HandleRequest(State & state, ScheduleHelper & scheduleHelper) { - if (!(request.type == HttpRequest::Type::GET || request.type == HttpRequest::Type::HEAD)) + if (!(state.request.type == HttpRequest::Type::GET || state.request.type == HttpRequest::Type::HEAD)) { return; } std::string path; - if (request.url.HasPath()) + if (state.request.url.HasPath()) { - path = root + request.url.GetPath(); + path = root + state.request.url.GetPath(); } else { @@ -66,20 +67,31 @@ namespace Middleware path = root + "/index.html"; } - if (ContainsDoubleDots(request.url.GetPath())) + if (ContainsDoubleDots(state.request.url.GetPath())) { // We cannot deal with this, we are not going to bother checking if // this double dot escapes our root directory return; } - if (request.type == HttpRequest::Type::GET && !TryReadAllBytes(path, response.content)) + if (access(path.c_str(), F_OK) != 0) { + // File does not exist return; } - response.code = HttpResponse::Code::OK; - response.contentType = Http::GetMimeType(path); + state.response.code = HttpResponse::Code::OK; + state.response.contentType = Http::GetMimeType(path); + + // Regular file descriptors are not supported by epoll, so we have to "just read it" + if (state.request.type == HttpRequest::Type::GET && !TryReadAllBytes(path, state.response.content)) + { + state.response.code = HttpResponse::Code::INTERNAL_SERVER_ERROR; + state.response.contentType = Http::GetMimeType(""); + } + + // HEAD request + state.finished = true; return; } diff --git a/src/middleware/staticcontent.hpp b/src/middleware/staticcontent.hpp index 534e88a..cd592de 100644 --- a/src/middleware/staticcontent.hpp +++ b/src/middleware/staticcontent.hpp @@ -10,9 +10,9 @@ namespace Middleware { private: std::string root; - + public: - virtual void HandleRequest(Http::Request const & request, Http::Response & response) override; + virtual void HandleRequest(State & state, ScheduleHelper & scheduleHelper) override; StaticContent(Logger & logger, std::string const & staticFileRoot); }; diff --git a/src/middleware/unimplemented.cpp b/src/middleware/unimplemented.cpp new file mode 100644 index 0000000..a4ec765 --- /dev/null +++ b/src/middleware/unimplemented.cpp @@ -0,0 +1,15 @@ +#include "unimplemented.hpp" + +namespace Middleware +{ + void Unimplemented::HandleRequest(State & state, ScheduleHelper & scheduleHelper) + { + state.finished = true; + state.response.code = HttpResponse::Code::NOT_IMPLEMENTED; + } + + Unimplemented::Unimplemented(Logger & logger) + : BaseMiddleware(logger) + { + } +} \ No newline at end of file diff --git a/src/middleware/unimplemented.hpp b/src/middleware/unimplemented.hpp new file mode 100644 index 0000000..8244f5d --- /dev/null +++ b/src/middleware/unimplemented.hpp @@ -0,0 +1,13 @@ +#pragma once +#include "base.hpp" + +namespace Middleware +{ + class Unimplemented : public BaseMiddleware + { + public: + virtual void HandleRequest(State & state, ScheduleHelper & scheduleHelper) override; + + Unimplemented(Logger & logger); + }; +} \ No newline at end of file diff --git a/src/scheduler.cpp b/src/scheduler.cpp new file mode 100644 index 0000000..ce22288 --- /dev/null +++ b/src/scheduler.cpp @@ -0,0 +1,89 @@ +#include +#include "scheduler.hpp" +#include + +void Scheduler::ScheduleFileDescriptor(int fd, int eventFlags, bool isReschedule) +{ + if (fd < 0) + { + logger.Warning("Attempted to schedule a non valid filedescriptor"); + } + + epoll_event event; + event.data.fd = fd; + event.events = eventFlags; + int const epollCtl = isReschedule ? EPOLL_CTL_MOD : EPOLL_CTL_ADD; + if (epoll_ctl(epollFileDescriptor, epollCtl, fd, &event) < 0) + { + logger.Fatal(std::to_string(errno)); + throw std::runtime_error("Error registering filedescriptor with epoll facilities"); + } +} + +std::vector Scheduler::WaitForEvents() +{ + ScheduleHelper scheduleHelper(*this); + + int const eventsHappened = epoll_wait(epollFileDescriptor, epollEvents.data(), static_cast(epollEvents.size()), -1); + std::vector readyFileDescriptors; + if (eventsHappened > 0) + { + readyFileDescriptors.resize(eventsHappened); + for (int i = 0; i < eventsHappened; ++i) + { + readyFileDescriptors[i] = epollEvents[i].data.fd; + } + } + + return readyFileDescriptors; +} + +ScheduleHelper Scheduler::GetScheduleHelper() +{ + return ScheduleHelper(*this); +} + +Scheduler::Scheduler(Logger & _logger) + : logger(_logger) +{ + epollFileDescriptor = epoll_create1(0); + if (epollFileDescriptor < 0) + { + throw std::runtime_error("Error creating epoll file descriptor"); + } + + epollEvents.resize(512); +} + +Scheduler::~Scheduler() +{ + if (epollFileDescriptor >= 0) + { + close(epollFileDescriptor); + } +} + +void ScheduleHelper::ScheduleReadFileDescriptor(int fd, bool oneshot, bool isReschedule) +{ + int flags = EPOLLIN; + if (oneshot) + { + flags |= EPOLLONESHOT; + } + scheduler.ScheduleFileDescriptor(fd, flags, isReschedule); +} + +void ScheduleHelper::ScheduleWriteFileDescriptor(int fd, bool oneshot, bool isReschedule) +{ + int flags = EPOLLOUT; + if (oneshot) + { + flags |= EPOLLONESHOT; + } + scheduler.ScheduleFileDescriptor(fd, flags, isReschedule); +} + +ScheduleHelper::ScheduleHelper(Scheduler & _scheduler) + : scheduler(_scheduler) +{ +} \ No newline at end of file diff --git a/src/scheduler.hpp b/src/scheduler.hpp new file mode 100644 index 0000000..4dfa12a --- /dev/null +++ b/src/scheduler.hpp @@ -0,0 +1,47 @@ +#pragma once +#include "logger.hpp" +#include +#include +#include + +class ScheduleHelper; // Forward declare + +//using callbackFunction = void (*)(int, ScheduleHelper); + +class Scheduler +{ + friend class ScheduleHelper; +private: + Logger & logger; + int epollFileDescriptor; +// std::map fileDescriptorToCallbackMap; + std::vector epollEvents; + + void ScheduleFileDescriptor(int fd, int eventFlags, bool isReschedule); + +public: + // Blocks until an event happens + // Returns a vector filedescriptors that are ready for processing + std::vector WaitForEvents(); + + ScheduleHelper GetScheduleHelper(); + + Scheduler(Logger & logger); + ~Scheduler(); + Scheduler(Scheduler & other) = delete; + Scheduler(Scheduler && other) = delete; + Scheduler & operator=(Scheduler & other) = delete; + Scheduler & operator=(Scheduler && other) = delete; +}; + +class ScheduleHelper +{ +private: + Scheduler & scheduler; + +public: + void ScheduleReadFileDescriptor(int fd, bool oneshot = true, bool isReschedule = false); + void ScheduleWriteFileDescriptor(int fd, bool oneshot = true, bool isReschedule = false); + + ScheduleHelper(Scheduler & scheduler); +}; \ No newline at end of file diff --git a/src/server/configuration.cpp b/src/server/configuration.cpp index 3446cd0..9e0ce11 100644 --- a/src/server/configuration.cpp +++ b/src/server/configuration.cpp @@ -1,44 +1,47 @@ #include "configuration.hpp" -int ServerConfiguration::GetMajorVersion() const +namespace Server { - return 0; -} + int Configuration::GetMajorVersion() const + { + return 0; + } -int ServerConfiguration::GetMinorVersion() const -{ - return 1; -} + int Configuration::GetMinorVersion() const + { + return 1; + } -std::string const & ServerConfiguration::GetWwwRoot() const -{ - return wwwRoot; -} + std::string const & Configuration::GetWwwRoot() const + { + return wwwRoot; + } -std::string const & ServerConfiguration::GetServerName() const -{ - return serverName; -} + std::string const & Configuration::GetServerName() const + { + return serverName; + } -int ServerConfiguration::GetPort() const -{ - return port; -} + int Configuration::GetPort() const + { + return port; + } -bool ServerConfiguration::IsValid() const -{ - return isValid; -} + bool Configuration::IsValid() const + { + return isValid; + } -bool ServerConfiguration::LoadFromFile(std::string const & filePath) -{ - // TODO implement - return false; -} + bool Configuration::LoadFromFile(std::string const & filePath) + { + // TODO implement + return false; + } -ServerConfiguration::ServerConfiguration() - : wwwRoot("./www"), - serverName("http-server"), - port(8080) -{ + Configuration::Configuration() + : wwwRoot("./www"), + serverName("http-server"), + port(8080) + { + } } \ No newline at end of file diff --git a/src/server/configuration.hpp b/src/server/configuration.hpp index a0e4585..46911d4 100644 --- a/src/server/configuration.hpp +++ b/src/server/configuration.hpp @@ -1,27 +1,30 @@ #pragma once #include -class ServerConfiguration +namespace Server { -private: - std::string wwwRoot; - std::string serverName; - int port; - bool isValid; + class Configuration + { + private: + std::string wwwRoot; + std::string serverName; + int port; + bool isValid; -public: - int GetMajorVersion() const; - int GetMinorVersion() const; - std::string const & GetWwwRoot() const; - std::string const & GetServerName() const; - int GetPort() const; - bool IsValid() const; + public: + int GetMajorVersion() const; + int GetMinorVersion() const; + std::string const & GetWwwRoot() const; + std::string const & GetServerName() const; + int GetPort() const; + bool IsValid() const; - bool LoadFromFile(std::string const & filePath); + bool LoadFromFile(std::string const & filePath); - ServerConfiguration(); - ~ServerConfiguration() = default; + Configuration(); + ~Configuration() = default; - ServerConfiguration(ServerConfiguration & other) = delete; - ServerConfiguration(ServerConfiguration && other) = delete; -}; \ No newline at end of file + Configuration(Configuration & other) = delete; + Configuration(Configuration && other) = delete; + }; +} \ No newline at end of file diff --git a/src/server/connectionoperator.cpp b/src/server/connectionoperator.cpp deleted file mode 100644 index 505e4f5..0000000 --- a/src/server/connectionoperator.cpp +++ /dev/null @@ -1,66 +0,0 @@ -#include "../middleware/notfound.hpp" -#include "../middleware/staticcontent.hpp" -#include "../logger.hpp" -#include "configuration.hpp" -#include "connectionoperator.hpp" -#include -#include - -std::vector ConnectionOperator::HandleNewConnection(int fd) -{ - auto requestBytes = Socket::ReadBytes(fd, 512); - Http::Request request; - Http::Response response; - try - { - request = Http::Request::Deserialize(requestBytes); - } - catch (std::runtime_error & e) - { - std::stringstream ss; - ss << "Error during parsing of request <"; - ss << e.what(); - ss << '>'; - logger.Error(ss.str()); - - response.code = HttpResponse::Code::BAD_REQUEST; - - return response.Serialize(); - } - - for(size_t i = 0; i < middlewares.size(); ++i) - { - middlewares[i]->HandleRequest(request, response); - } - - if (response.code == HttpResponse::Code::UNKNOWN) - { - std::stringstream ss; - ss << "Unhandled "; - ss << HttpRequest::typeStrings[static_cast(request.type)]; - ss << " request for file <"; - ss << request.url.GetPath(); - ss << '>'; - logger.Error(ss.str()); - - response.code = HttpResponse::Code::NOT_IMPLEMENTED; - - return response.Serialize(); - } - - return response.Serialize(); -} - -ConnectionOperator::ConnectionOperator(Logger & _logger, ServerConfiguration const & serverConfiguration) - : logger(_logger) -{ - // Base static file server - auto const & staticFileRoot = serverConfiguration.GetWwwRoot(); - if (staticFileRoot.size() > 0) - { - middlewares.emplace_back(std::make_unique(_logger, staticFileRoot)); - } - - // ALWAYS LAST! - middlewares.emplace_back(std::make_unique(_logger)); -} \ No newline at end of file diff --git a/src/server/connectionoperator.hpp b/src/server/connectionoperator.hpp deleted file mode 100644 index 7bcbb91..0000000 --- a/src/server/connectionoperator.hpp +++ /dev/null @@ -1,19 +0,0 @@ -#pragma once -#include "../logger.hpp" -#include "../middleware/base.hpp" -#include -#include "socket.hpp" -#include -#include - -class ConnectionOperator -{ -private: - Logger & logger; - std::vector> middlewares; - -public: - std::vector HandleNewConnection(int fd); - - ConnectionOperator(Logger & logger, ServerConfiguration const & serverConfiguration); -}; \ No newline at end of file diff --git a/src/server/eventtype.hpp b/src/server/eventtype.hpp new file mode 100644 index 0000000..16a9818 --- /dev/null +++ b/src/server/eventtype.hpp @@ -0,0 +1,12 @@ +#pragma once + +namespace Server +{ + enum class EventType + { + Unknown = -1, + NewConnection = 0, + ReadRequest = 1, + WriteResponse + }; +} diff --git a/src/server/server.cpp b/src/server/server.cpp index c70e355..dd2fb2e 100755 --- a/src/server/server.cpp +++ b/src/server/server.cpp @@ -2,119 +2,163 @@ #include "configuration.hpp" #include "server.hpp" #include +#include "socket.hpp" #include #include #include #include -void HttpServer::HandleEpollInEvent(int fd) +namespace Server { - if (fd != listeningSocketFileDescriptor) + void HttpServer::RegisterNewConnection(int const listeningSocketFd) { - logger.Info("EPOLLIN Attempted to handle a non registered file descriptor"); - - return; - } - - unsigned sockaddrSize = sizeof(sockaddr_in); - int const connectionFileDescriptor = accept( - listeningSocketFileDescriptor, - reinterpret_cast(&socketAddress), - &sockaddrSize); - - socketWriteMap[connectionFileDescriptor] = connectionOperator.HandleNewConnection(connectionFileDescriptor); - - epoll_event event; - event.data.fd = connectionFileDescriptor; - event.events = EPOLLOUT | EPOLLONESHOT; - if (epoll_ctl(pollFileDescriptor, EPOLL_CTL_ADD, connectionFileDescriptor, &event) < 0) - { - logger.Error("Error registering file descriptor for EPOLLOUT"); - } -} - -void HttpServer::HandleEpollOutEvent(int fd) -{ - auto result = socketWriteMap.find(fd); - if (result == socketWriteMap.end()) - { - logger.Error("EPOLLOUT Event for unexpected fd"); - return; - } - - Socket::WriteBytes(fd, result->second); - close(fd); - socketWriteMap.erase(result); -} - -void HttpServer::Execute() -{ - std::vector epollEvents; - epollEvents.resize(1000); - - while(isOpen) - { - int eventsHappened = epoll_wait(pollFileDescriptor, epollEvents.data(), static_cast(epollEvents.size()), -1); - for (int i = 0; i < eventsHappened; ++i) + unsigned sockaddrSize = sizeof(sockaddr_in); + int const clientFd = accept( + listeningSocketFd, + reinterpret_cast(&socketAddress), + &sockaddrSize); + + if (clientFd < 0) { - int const fd = epollEvents[i].data.fd; - if (epollEvents[i].events & EPOLLIN) + logger.Info("Connection dropped on accepting it"); + return; + } + + scheduleHelper.ScheduleReadFileDescriptor(clientFd); + fileDescriptorEventTypeMap[clientFd] = EventType::ReadRequest; + } + + void HttpServer::ReadRequest(int const socketFd) + { + auto bytes = Socket::ReadBytes(socketFd, 512); + if (bytes.size() < 1) + { + logger.Info("Could not read anything from socket"); + return; + } + + auto & request = socketRequestMap[socketFd]; + request.Deserialize(bytes, logger); + + auto & response = socketResponseMap[socketFd]; + + auto middlewareResult = middleware.Start(socketFd, scheduleHelper, request, response); + HandleMiddlewareResponse(middlewareResult); + } + + void HttpServer::HandleMiddlewareEvent(int const eventFd) + { + auto middlewareResult = middleware.Continue(eventFd, scheduleHelper); + HandleMiddlewareResponse(middlewareResult); + } + + void HttpServer::HandleMiddlewareResponse(Middleware::MiddlewareResult const & middlewareResult) + { + if (middlewareResult.writeResult) + { + fileDescriptorEventTypeMap[middlewareResult.socketFd] = EventType::WriteResponse; + scheduleHelper.ScheduleWriteFileDescriptor(middlewareResult.socketFd, true, true); + } + } + + void HttpServer::WriteResponse(int const socketFd) + { + auto requestResult = socketRequestMap.find(socketFd); + if (requestResult != socketRequestMap.end()) + { + socketRequestMap.erase(requestResult); + } + + auto responseResult = socketResponseMap.find(socketFd); + Socket::WriteBytes(socketFd, responseResult->second.Serialize()); + if (responseResult != socketResponseMap.end()) + { + socketResponseMap.erase(responseResult); + } + } + + void HttpServer::Execute() + { + while(isOpen) + { + auto events = scheduler.WaitForEvents(); + for (size_t i = 0; i < events.size(); ++i) { - HandleEpollInEvent(fd); - } - else - { - HandleEpollOutEvent(fd); + int const fd = events[i]; + + auto eventTypeMapIter = fileDescriptorEventTypeMap.find(fd); + if (eventTypeMapIter == fileDescriptorEventTypeMap.end()) + { + HandleMiddlewareEvent(fd); + continue; + } + + if (eventTypeMapIter->second != EventType::NewConnection) + { + fileDescriptorEventTypeMap.erase(eventTypeMapIter); + } + + switch(eventTypeMapIter->second) + { + case EventType::NewConnection: + RegisterNewConnection(fd); + break; + + case EventType::ReadRequest: + ReadRequest(fd); + break; + + case EventType::WriteResponse: + WriteResponse(fd); + close(fd); + break; + + case EventType::Unknown: + default: + close(fd); + logger.Error("FD with unknown EventType fired and was closed"); + break; + } } } } + + HttpServer::HttpServer(Logger & _logger, Configuration const & serverConfiguration) + : logger(_logger), + scheduler(_logger), + scheduleHelper(scheduler), + listeningSocketFileDescriptor(-1), + middleware(Middleware::Middleware::CreateWithStaticFile(_logger, serverConfiguration.GetWwwRoot())), + isOpen(true) + { + socketAddress.sin_family = AF_INET; + socketAddress.sin_addr.s_addr = INADDR_ANY; + socketAddress.sin_port = htons(serverConfiguration.GetPort()); + + listeningSocketFileDescriptor = ListeningSocket::Create(socketAddress, 1000); + if (listeningSocketFileDescriptor < 0) + { + throw std::runtime_error("Error creating listening socket"); + } + + scheduleHelper.ScheduleReadFileDescriptor(listeningSocketFileDescriptor, false); + fileDescriptorEventTypeMap[listeningSocketFileDescriptor] = EventType::NewConnection; + + std::stringstream ss; + ss << "Listening on port " << serverConfiguration.GetPort(); + logger.Info(ss.str()); + } + + HttpServer::~HttpServer() + { + if (listeningSocketFileDescriptor >= 0) + { + close(listeningSocketFileDescriptor); + } + + for (auto & item : fileDescriptorEventTypeMap) + { + close(item.first); + } + } } - -HttpServer::HttpServer(Logger & _logger, ServerConfiguration const & serverConfiguration) - : logger(_logger), - pollFileDescriptor(-1), - listeningSocketFileDescriptor(-1), - connectionOperator(_logger, serverConfiguration), - isOpen(true) -{ - pollFileDescriptor = epoll_create1(0); - if (pollFileDescriptor < 0) - { - throw std::runtime_error("Error creating epoll file descriptor"); - } - - socketAddress.sin_family = AF_INET; - socketAddress.sin_addr.s_addr = INADDR_ANY; - socketAddress.sin_port = htons(serverConfiguration.GetPort()); - - listeningSocketFileDescriptor = ListeningSocket::Create(socketAddress, 10000); - if (listeningSocketFileDescriptor < 0) - { - throw std::runtime_error("Error creating listening socket"); - } - - epoll_event event; - event.data.fd = listeningSocketFileDescriptor; - event.events = EPOLLIN; - if (epoll_ctl(pollFileDescriptor, EPOLL_CTL_ADD, listeningSocketFileDescriptor, &event) < 0) - { - throw std::runtime_error("Error registering listening socket with epoll facilities"); - } - - std::stringstream ss; - ss << "Listening on port " << serverConfiguration.GetPort(); - logger.Info(ss.str()); -} - -HttpServer::~HttpServer() -{ - if (listeningSocketFileDescriptor >= 0) - { - close(listeningSocketFileDescriptor); - } - - if (pollFileDescriptor >= 0) - { - close(pollFileDescriptor); - } -} \ No newline at end of file diff --git a/src/server/server.hpp b/src/server/server.hpp index 171a1b6..a70dd6c 100755 --- a/src/server/server.hpp +++ b/src/server/server.hpp @@ -1,28 +1,47 @@ #pragma once +#include "../http/request.hpp" +#include "../http/response.hpp" #include "../logger.hpp" -#include "connectionoperator.hpp" +#include "../middleware/middleware.hpp" +#include "../scheduler.hpp" +#include "eventtype.hpp" #include +#include -class HttpServer +namespace Server { -private: - Logger & logger; - int pollFileDescriptor; - sockaddr_in socketAddress; - int listeningSocketFileDescriptor; - std::map> socketWriteMap; - ConnectionOperator connectionOperator; - bool isOpen; + class HttpServer + { + private: + Logger & logger; + Scheduler scheduler; + ScheduleHelper scheduleHelper; - void HandleEpollInEvent(int fd); + sockaddr_in socketAddress; + int listeningSocketFileDescriptor; - void HandleEpollOutEvent(int fd); + std::map fileDescriptorEventTypeMap; + std::map socketRequestMap; + std::map socketResponseMap; -public: - void Execute(); - - HttpServer(Logger & logger, ServerConfiguration const & serverConfiguration); - ~HttpServer(); - HttpServer(HttpServer & other) = delete; - HttpServer & operator=(HttpServer & other) = delete; -}; \ No newline at end of file + Middleware::Middleware middleware; + + bool isOpen; + + void RegisterNewConnection(int const listeningSocketFd); + + void ReadRequest(int const socketFd); + + void HandleMiddlewareEvent(int const eventFd); + + void HandleMiddlewareResponse(Middleware::MiddlewareResult const & middlewareResult); + + void WriteResponse(int const socketFd); + + public: + void Execute(); + + HttpServer(Logger & logger, Configuration const & serverConfiguration); + ~HttpServer(); + }; +} \ No newline at end of file diff --git a/src/server/socket.cpp b/src/server/socket.cpp index 8ee3f24..2e0657d 100644 --- a/src/server/socket.cpp +++ b/src/server/socket.cpp @@ -2,7 +2,7 @@ #include #include -namespace ListeningSocket +namespace Server::ListeningSocket { int Create(sockaddr_in & socketAddress, int const connectionLimit) { @@ -40,7 +40,7 @@ namespace ListeningSocket } } -namespace Socket +namespace Server::Socket { std::vector ReadBytes(int fd, size_t limit) { diff --git a/src/server/socket.hpp b/src/server/socket.hpp index 4f62227..848f763 100644 --- a/src/server/socket.hpp +++ b/src/server/socket.hpp @@ -3,12 +3,12 @@ #include #include -namespace ListeningSocket +namespace Server::ListeningSocket { int Create(sockaddr_in & socketAddress, int const connectionLimit); } -namespace Socket +namespace Server::Socket { std::vector ReadBytes(int fd, size_t limit);