Rewrite middleware to support epoll

This commit is contained in:
2019-06-29 12:46:29 +02:00
parent 72a6a745ff
commit 4d90d18660
27 changed files with 670 additions and 318 deletions

View File

@@ -17,31 +17,30 @@ namespace Http
return static_cast<T>(-1); return static_cast<T>(-1);
} }
Http::Request Request::Deserialize(std::vector<char> const & bytes) void Request::Deserialize(std::vector<char> const & bytes, Logger & logger)
{ {
// TODO serialize more than just the start // TODO serialize more than just the start
Http::Request request;
std::stringstream ss(std::string(bytes.begin(), bytes.end())); std::stringstream ss(std::string(bytes.begin(), bytes.end()));
std::string requestTypeString; std::string requestTypeString;
ss >> requestTypeString; ss >> requestTypeString;
request.type = ToEnum<HttpRequest::Type>(requestTypeString, HttpRequest::typeStrings); type = ToEnum<HttpRequest::Type>(requestTypeString, HttpRequest::typeStrings);
if (request.type == HttpRequest::Type::UNKNOWN) if (type == HttpRequest::Type::UNKNOWN)
{ {
throw std::runtime_error("Bad request type"); logger.Error("Request::Deserialize: Bad request type");
return;
} }
std::string rawUrl; std::string rawUrl;
ss >> rawUrl; ss >> rawUrl;
if(!request.url.TryParseFromUrlString(rawUrl)) if(!url.TryParseFromUrlString(rawUrl))
{ {
throw std::runtime_error("Bad url in request"); logger.Error("Request::Deserialize: Bad url in request");
type = HttpRequest::Type::UNKNOWN;
return;
} }
std::string httpProtocolString; std::string httpProtocolString;
ss >> httpProtocolString; ss >> httpProtocolString;
return request;
} }
} }

View File

@@ -1,5 +1,6 @@
#pragma once #pragma once
#include "../constants/httprequest.hpp" #include "../constants/httprequest.hpp"
#include "../logger.hpp"
#include "url.hpp" #include "url.hpp"
namespace Http namespace Http
@@ -9,6 +10,6 @@ namespace Http
HttpRequest::Type type; HttpRequest::Type type;
Url url; Url url;
static Request Deserialize(std::vector<char> const & bytes); void Deserialize(std::vector<char> const & bytes, Logger & logger);
}; };
} }

View File

@@ -11,11 +11,21 @@ void Logger::Success(const std::string & s)
Log("SUCCESS", s); Log("SUCCESS", s);
} }
void Logger::Fatal(const std::string & s)
{
Log("FATAL", s);
}
void Logger::Error(const std::string & s) void Logger::Error(const std::string & s)
{ {
Log("ERROR", s); Log("ERROR", s);
} }
void Logger::Warning(const std::string & s)
{
Log("WARNING", s);
}
void Logger::Info(const std::string & s) void Logger::Info(const std::string & s)
{ {
Log("INFO", s); Log("INFO", s);

View File

@@ -5,7 +5,9 @@ class Logger
{ {
public: public:
void Success(const std::string & s); void Success(const std::string & s);
void Fatal(const std::string & s);
void Error(const std::string & s); void Error(const std::string & s);
void Warning(const std::string & s);
void Info(const std::string & s); void Info(const std::string & s);
void Debug(const std::string & s); void Debug(const std::string & s);

View File

@@ -5,9 +5,10 @@
int main(int argc, char ** argv) int main(int argc, char ** argv)
{ {
ServerConfiguration serverConfiguration;
Logger logger; Logger logger;
HttpServer httpServer(logger, serverConfiguration);
Server::Configuration serverConfiguration;
Server::HttpServer httpServer(logger, serverConfiguration);
httpServer.Execute(); httpServer.Execute();

View File

@@ -2,6 +2,17 @@
namespace Middleware namespace Middleware
{ {
State::State(int const socketFileDescriptor, Http::Request const & _request, Http::Response & _response)
: finished(false),
awaiting(false),
socketFd(socketFileDescriptor),
awaitingFd(-1),
request(_request),
response(_response),
middlewareIndex(0)
{
}
BaseMiddleware::BaseMiddleware(Logger & _logger) BaseMiddleware::BaseMiddleware(Logger & _logger)
: logger(_logger) : logger(_logger)
{ {

View File

@@ -2,18 +2,35 @@
#include "../http/request.hpp" #include "../http/request.hpp"
#include "../http/response.hpp" #include "../http/response.hpp"
#include "../logger.hpp" #include "../logger.hpp"
#include "../scheduler.hpp"
#include <cstdint> #include <cstdint>
#include <string> #include <string>
namespace Middleware namespace Middleware
{ {
struct State
{
bool finished;
bool awaiting;
int const socketFd;
int awaitingFd;
Http::Request const & request;
Http::Response & response;
size_t middlewareIndex;
State(int const socketFileDescriptor, Http::Request const & request, Http::Response & response);
};
class BaseMiddleware class BaseMiddleware
{ {
protected: protected:
Logger & logger; Logger & logger;
public: public:
virtual void HandleRequest(Http::Request const & request, Http::Response & response) = 0; virtual void HandleRequest(State & state, ScheduleHelper & scheduleHelper) = 0;
BaseMiddleware(Logger & logger); BaseMiddleware(Logger & logger);
virtual ~BaseMiddleware() = default; virtual ~BaseMiddleware() = default;

View File

@@ -0,0 +1,96 @@
#include "middleware.hpp"
#include "notfound.hpp"
#include "staticcontent.hpp"
#include "unimplemented.hpp"
#include <unistd.h>
namespace Middleware
{
MiddlewareResult::MiddlewareResult(bool const _writeResult, int const _socketFd)
: writeResult(_writeResult),
socketFd(_socketFd)
{
}
Middleware::Middleware(Logger & _logger)
: logger(_logger)
{
}
MiddlewareResult Middleware::Loop(State & state, ScheduleHelper & scheduleHelper)
{
for(size_t i = state.middlewareIndex; i < middlewares.size(); ++i)
{
middlewares[i]->HandleRequest(state, scheduleHelper);
if (state.finished)
{
break;
}
if (state.awaiting)
{
state.middlewareIndex = i;
eventFdToSocketFdMap[state.awaitingFd] = state.socketFd;
return MiddlewareResult (false, state.socketFd);
}
}
int const socketFd = state.socketFd;
auto stateResult = socketFdToStateMap.find(socketFd);
if (stateResult != socketFdToStateMap.end())
{
socketFdToStateMap.erase(stateResult);
// The state parameter is now invalid!!
}
return MiddlewareResult(true, socketFd);
}
MiddlewareResult Middleware::Start(int const socketFd, ScheduleHelper & scheduleHelper, Http::Request const & request, Http::Response & response)
{
auto & statePtr = socketFdToStateMap[socketFd];
statePtr = std::make_unique<State>(socketFd, request, response);
return Loop(*statePtr, scheduleHelper);
}
MiddlewareResult Middleware::Continue(int const fd, ScheduleHelper & scheduleHelper)
{
auto socketFdResult = eventFdToSocketFdMap.find(fd);
if (socketFdResult == eventFdToSocketFdMap.end())
{
close(fd);
logger.Warning("Dangeling filedescriptor closed in middleware");
return MiddlewareResult(false, -1);
}
int const socketFd = socketFdResult->second;
eventFdToSocketFdMap.erase(socketFdResult);
auto stateResult = socketFdToStateMap.find(socketFd);
if (stateResult == socketFdToStateMap.end())
{
close(fd);
close(socketFd);
logger.Warning("Stateless socket file descriptor and associated event file descriptor closed");
return MiddlewareResult(false, -1);
}
return Loop(*(stateResult->second), scheduleHelper);
}
void Middleware::UseMiddleware(std::unique_ptr<BaseMiddleware> && middleware)
{
middlewares.emplace_back(std::move(middleware));
}
Middleware Middleware::CreateWithStaticFile(Logger & logger, std::string const & rootDirectory)
{
Middleware middleware(logger);
middleware.UseMiddleware(std::make_unique<StaticContent>(logger, rootDirectory));
middleware.UseMiddleware(std::make_unique<NotFound>(logger));
middleware.UseMiddleware(std::make_unique<Unimplemented>(logger));
return middleware;
}
}

View File

@@ -0,0 +1,41 @@
#pragma once
#include "../http/request.hpp"
#include "../http/response.hpp"
#include "../logger.hpp"
#include "../scheduler.hpp"
#include "base.hpp"
#include <memory>
#include <string>
#include <vector>
namespace Middleware
{
struct MiddlewareResult
{
bool const writeResult;
int const socketFd;
MiddlewareResult(bool const writeResult, int const socketFd);
};
class Middleware
{
private:
Logger & logger;
std::vector<std::unique_ptr<BaseMiddleware>> middlewares;
std::map<int, std::unique_ptr<State>> socketFdToStateMap;
std::map<int, int> eventFdToSocketFdMap;
Middleware(Logger & logger);
MiddlewareResult Loop(State & state, ScheduleHelper & scheduleHelper);
public:
MiddlewareResult Start(int const socketFd, ScheduleHelper & scheduleHelper, Http::Request const & request, Http::Response & response);
MiddlewareResult Continue(int const fd, ScheduleHelper & scheduleHelper);
void UseMiddleware(std::unique_ptr<BaseMiddleware> && middleware);
static Middleware CreateWithStaticFile(Logger & logger, std::string const & rootDirectory);
};
}

View File

@@ -4,30 +4,32 @@
namespace Middleware namespace Middleware
{ {
void NotFound::HandleRequest(Http::Request const & request, Http::Response & response) void NotFound::HandleRequest(State & state, ScheduleHelper & scheduleHelper)
{ {
if (response.code != HttpResponse::Code::UNKNOWN || if (state.response.code != HttpResponse::Code::UNKNOWN ||
!(request.type == HttpRequest::Type::GET || request.type == HttpRequest::Type::HEAD)) !(state.request.type == HttpRequest::Type::GET || state.request.type == HttpRequest::Type::HEAD))
{ {
return; return;
} }
response.code = HttpResponse::Code::NOT_FOUND; state.response.code = HttpResponse::Code::NOT_FOUND;
if (request.type == HttpRequest::Type::GET) if (state.request.type == HttpRequest::Type::GET)
{ {
response.contentType = Http::GetMimeType(Http::FileType::HTML); state.response.contentType = Http::GetMimeType(Http::FileType::HTML);
std::stringstream ss; std::stringstream ss;
ss << "<!DOCTYPE html><html><head></head><body>"; ss << "<!DOCTYPE html><html><head></head><body>";
ss << "<h1>404 - File Not Found</h1>"; ss << "<h1>404 - File Not Found</h1>";
ss << "<p>File: " << request.url.GetPath() << "<p>"; ss << "<p>File: " << state.request.url.GetPath() << "<p>";
ss << "</body></html>"; ss << "</body></html>";
auto responseContent = ss.str(); auto responseContent = ss.str();
response.content.insert(response.content.begin(), state.response.content.insert(state.response.content.begin(),
responseContent.begin(), responseContent.begin(),
responseContent.end()); responseContent.end());
} }
state.finished = true;
} }
NotFound::NotFound(Logger & _logger) NotFound::NotFound(Logger & _logger)

View File

@@ -6,7 +6,7 @@ namespace Middleware
class NotFound : public BaseMiddleware class NotFound : public BaseMiddleware
{ {
public: public:
void HandleRequest(Http::Request const & request, Http::Response & Response) override; void HandleRequest(State & state, ScheduleHelper & scheduleHelper) override;
NotFound(Logger & logger); NotFound(Logger & logger);
}; };

View File

@@ -4,6 +4,7 @@
#include <ios> #include <ios>
#include "staticcontent.hpp" #include "staticcontent.hpp"
#include <sstream> #include <sstream>
#include <unistd.h>
namespace Middleware namespace Middleware
{ {
@@ -48,17 +49,17 @@ namespace Middleware
return false; return false;
} }
void StaticContent::HandleRequest(Http::Request const & request, Http::Response & response) void StaticContent::HandleRequest(State & state, ScheduleHelper & scheduleHelper)
{ {
if (!(request.type == HttpRequest::Type::GET || request.type == HttpRequest::Type::HEAD)) if (!(state.request.type == HttpRequest::Type::GET || state.request.type == HttpRequest::Type::HEAD))
{ {
return; return;
} }
std::string path; std::string path;
if (request.url.HasPath()) if (state.request.url.HasPath())
{ {
path = root + request.url.GetPath(); path = root + state.request.url.GetPath();
} }
else else
{ {
@@ -66,20 +67,31 @@ namespace Middleware
path = root + "/index.html"; path = root + "/index.html";
} }
if (ContainsDoubleDots(request.url.GetPath())) if (ContainsDoubleDots(state.request.url.GetPath()))
{ {
// We cannot deal with this, we are not going to bother checking if // We cannot deal with this, we are not going to bother checking if
// this double dot escapes our root directory // this double dot escapes our root directory
return; return;
} }
if (request.type == HttpRequest::Type::GET && !TryReadAllBytes(path, response.content)) if (access(path.c_str(), F_OK) != 0)
{ {
// File does not exist
return; return;
} }
response.code = HttpResponse::Code::OK; state.response.code = HttpResponse::Code::OK;
response.contentType = Http::GetMimeType(path); state.response.contentType = Http::GetMimeType(path);
// Regular file descriptors are not supported by epoll, so we have to "just read it"
if (state.request.type == HttpRequest::Type::GET && !TryReadAllBytes(path, state.response.content))
{
state.response.code = HttpResponse::Code::INTERNAL_SERVER_ERROR;
state.response.contentType = Http::GetMimeType("");
}
// HEAD request
state.finished = true;
return; return;
} }

View File

@@ -12,7 +12,7 @@ namespace Middleware
std::string root; std::string root;
public: public:
virtual void HandleRequest(Http::Request const & request, Http::Response & response) override; virtual void HandleRequest(State & state, ScheduleHelper & scheduleHelper) override;
StaticContent(Logger & logger, std::string const & staticFileRoot); StaticContent(Logger & logger, std::string const & staticFileRoot);
}; };

View File

@@ -0,0 +1,15 @@
#include "unimplemented.hpp"
namespace Middleware
{
void Unimplemented::HandleRequest(State & state, ScheduleHelper & scheduleHelper)
{
state.finished = true;
state.response.code = HttpResponse::Code::NOT_IMPLEMENTED;
}
Unimplemented::Unimplemented(Logger & logger)
: BaseMiddleware(logger)
{
}
}

View File

@@ -0,0 +1,13 @@
#pragma once
#include "base.hpp"
namespace Middleware
{
class Unimplemented : public BaseMiddleware
{
public:
virtual void HandleRequest(State & state, ScheduleHelper & scheduleHelper) override;
Unimplemented(Logger & logger);
};
}

89
src/scheduler.cpp Normal file
View File

@@ -0,0 +1,89 @@
#include <errno.h>
#include "scheduler.hpp"
#include <unistd.h>
void Scheduler::ScheduleFileDescriptor(int fd, int eventFlags, bool isReschedule)
{
if (fd < 0)
{
logger.Warning("Attempted to schedule a non valid filedescriptor");
}
epoll_event event;
event.data.fd = fd;
event.events = eventFlags;
int const epollCtl = isReschedule ? EPOLL_CTL_MOD : EPOLL_CTL_ADD;
if (epoll_ctl(epollFileDescriptor, epollCtl, fd, &event) < 0)
{
logger.Fatal(std::to_string(errno));
throw std::runtime_error("Error registering filedescriptor with epoll facilities");
}
}
std::vector<int> Scheduler::WaitForEvents()
{
ScheduleHelper scheduleHelper(*this);
int const eventsHappened = epoll_wait(epollFileDescriptor, epollEvents.data(), static_cast<int>(epollEvents.size()), -1);
std::vector<int> readyFileDescriptors;
if (eventsHappened > 0)
{
readyFileDescriptors.resize(eventsHappened);
for (int i = 0; i < eventsHappened; ++i)
{
readyFileDescriptors[i] = epollEvents[i].data.fd;
}
}
return readyFileDescriptors;
}
ScheduleHelper Scheduler::GetScheduleHelper()
{
return ScheduleHelper(*this);
}
Scheduler::Scheduler(Logger & _logger)
: logger(_logger)
{
epollFileDescriptor = epoll_create1(0);
if (epollFileDescriptor < 0)
{
throw std::runtime_error("Error creating epoll file descriptor");
}
epollEvents.resize(512);
}
Scheduler::~Scheduler()
{
if (epollFileDescriptor >= 0)
{
close(epollFileDescriptor);
}
}
void ScheduleHelper::ScheduleReadFileDescriptor(int fd, bool oneshot, bool isReschedule)
{
int flags = EPOLLIN;
if (oneshot)
{
flags |= EPOLLONESHOT;
}
scheduler.ScheduleFileDescriptor(fd, flags, isReschedule);
}
void ScheduleHelper::ScheduleWriteFileDescriptor(int fd, bool oneshot, bool isReschedule)
{
int flags = EPOLLOUT;
if (oneshot)
{
flags |= EPOLLONESHOT;
}
scheduler.ScheduleFileDescriptor(fd, flags, isReschedule);
}
ScheduleHelper::ScheduleHelper(Scheduler & _scheduler)
: scheduler(_scheduler)
{
}

47
src/scheduler.hpp Normal file
View File

@@ -0,0 +1,47 @@
#pragma once
#include "logger.hpp"
#include <map>
#include <sys/epoll.h>
#include <vector>
class ScheduleHelper; // Forward declare
//using callbackFunction = void (*)(int, ScheduleHelper);
class Scheduler
{
friend class ScheduleHelper;
private:
Logger & logger;
int epollFileDescriptor;
// std::map<int, callbackFunction> fileDescriptorToCallbackMap;
std::vector<epoll_event> epollEvents;
void ScheduleFileDescriptor(int fd, int eventFlags, bool isReschedule);
public:
// Blocks until an event happens
// Returns a vector filedescriptors that are ready for processing
std::vector<int> WaitForEvents();
ScheduleHelper GetScheduleHelper();
Scheduler(Logger & logger);
~Scheduler();
Scheduler(Scheduler & other) = delete;
Scheduler(Scheduler && other) = delete;
Scheduler & operator=(Scheduler & other) = delete;
Scheduler & operator=(Scheduler && other) = delete;
};
class ScheduleHelper
{
private:
Scheduler & scheduler;
public:
void ScheduleReadFileDescriptor(int fd, bool oneshot = true, bool isReschedule = false);
void ScheduleWriteFileDescriptor(int fd, bool oneshot = true, bool isReschedule = false);
ScheduleHelper(Scheduler & scheduler);
};

View File

@@ -1,44 +1,47 @@
#include "configuration.hpp" #include "configuration.hpp"
int ServerConfiguration::GetMajorVersion() const namespace Server
{ {
int Configuration::GetMajorVersion() const
{
return 0; return 0;
} }
int ServerConfiguration::GetMinorVersion() const int Configuration::GetMinorVersion() const
{ {
return 1; return 1;
} }
std::string const & ServerConfiguration::GetWwwRoot() const std::string const & Configuration::GetWwwRoot() const
{ {
return wwwRoot; return wwwRoot;
} }
std::string const & ServerConfiguration::GetServerName() const std::string const & Configuration::GetServerName() const
{ {
return serverName; return serverName;
} }
int ServerConfiguration::GetPort() const int Configuration::GetPort() const
{ {
return port; return port;
} }
bool ServerConfiguration::IsValid() const bool Configuration::IsValid() const
{ {
return isValid; return isValid;
} }
bool ServerConfiguration::LoadFromFile(std::string const & filePath) bool Configuration::LoadFromFile(std::string const & filePath)
{ {
// TODO implement // TODO implement
return false; return false;
} }
ServerConfiguration::ServerConfiguration() Configuration::Configuration()
: wwwRoot("./www"), : wwwRoot("./www"),
serverName("http-server"), serverName("http-server"),
port(8080) port(8080)
{ {
}
} }

View File

@@ -1,15 +1,17 @@
#pragma once #pragma once
#include <string> #include <string>
class ServerConfiguration namespace Server
{ {
private: class Configuration
{
private:
std::string wwwRoot; std::string wwwRoot;
std::string serverName; std::string serverName;
int port; int port;
bool isValid; bool isValid;
public: public:
int GetMajorVersion() const; int GetMajorVersion() const;
int GetMinorVersion() const; int GetMinorVersion() const;
std::string const & GetWwwRoot() const; std::string const & GetWwwRoot() const;
@@ -19,9 +21,10 @@ public:
bool LoadFromFile(std::string const & filePath); bool LoadFromFile(std::string const & filePath);
ServerConfiguration(); Configuration();
~ServerConfiguration() = default; ~Configuration() = default;
ServerConfiguration(ServerConfiguration & other) = delete; Configuration(Configuration & other) = delete;
ServerConfiguration(ServerConfiguration && other) = delete; Configuration(Configuration && other) = delete;
}; };
}

View File

@@ -1,66 +0,0 @@
#include "../middleware/notfound.hpp"
#include "../middleware/staticcontent.hpp"
#include "../logger.hpp"
#include "configuration.hpp"
#include "connectionoperator.hpp"
#include <cstdio>
#include <sstream>
std::vector<char> ConnectionOperator::HandleNewConnection(int fd)
{
auto requestBytes = Socket::ReadBytes(fd, 512);
Http::Request request;
Http::Response response;
try
{
request = Http::Request::Deserialize(requestBytes);
}
catch (std::runtime_error & e)
{
std::stringstream ss;
ss << "Error during parsing of request <";
ss << e.what();
ss << '>';
logger.Error(ss.str());
response.code = HttpResponse::Code::BAD_REQUEST;
return response.Serialize();
}
for(size_t i = 0; i < middlewares.size(); ++i)
{
middlewares[i]->HandleRequest(request, response);
}
if (response.code == HttpResponse::Code::UNKNOWN)
{
std::stringstream ss;
ss << "Unhandled ";
ss << HttpRequest::typeStrings[static_cast<int>(request.type)];
ss << " request for file <";
ss << request.url.GetPath();
ss << '>';
logger.Error(ss.str());
response.code = HttpResponse::Code::NOT_IMPLEMENTED;
return response.Serialize();
}
return response.Serialize();
}
ConnectionOperator::ConnectionOperator(Logger & _logger, ServerConfiguration const & serverConfiguration)
: logger(_logger)
{
// Base static file server
auto const & staticFileRoot = serverConfiguration.GetWwwRoot();
if (staticFileRoot.size() > 0)
{
middlewares.emplace_back(std::make_unique<Middleware::StaticContent>(_logger, staticFileRoot));
}
// ALWAYS LAST!
middlewares.emplace_back(std::make_unique<Middleware::NotFound>(_logger));
}

View File

@@ -1,19 +0,0 @@
#pragma once
#include "../logger.hpp"
#include "../middleware/base.hpp"
#include <memory>
#include "socket.hpp"
#include <string>
#include <vector>
class ConnectionOperator
{
private:
Logger & logger;
std::vector<std::unique_ptr<Middleware::BaseMiddleware>> middlewares;
public:
std::vector<char> HandleNewConnection(int fd);
ConnectionOperator(Logger & logger, ServerConfiguration const & serverConfiguration);
};

12
src/server/eventtype.hpp Normal file
View File

@@ -0,0 +1,12 @@
#pragma once
namespace Server
{
enum class EventType
{
Unknown = -1,
NewConnection = 0,
ReadRequest = 1,
WriteResponse
};
}

View File

@@ -2,119 +2,163 @@
#include "configuration.hpp" #include "configuration.hpp"
#include "server.hpp" #include "server.hpp"
#include <set> #include <set>
#include "socket.hpp"
#include <sstream> #include <sstream>
#include <stdexcept> #include <stdexcept>
#include <sys/epoll.h> #include <sys/epoll.h>
#include <unistd.h> #include <unistd.h>
void HttpServer::HandleEpollInEvent(int fd) namespace Server
{ {
if (fd != listeningSocketFileDescriptor) void HttpServer::RegisterNewConnection(int const listeningSocketFd)
{ {
logger.Info("EPOLLIN Attempted to handle a non registered file descriptor");
return;
}
unsigned sockaddrSize = sizeof(sockaddr_in); unsigned sockaddrSize = sizeof(sockaddr_in);
int const connectionFileDescriptor = accept( int const clientFd = accept(
listeningSocketFileDescriptor, listeningSocketFd,
reinterpret_cast<sockaddr *>(&socketAddress), reinterpret_cast<sockaddr *>(&socketAddress),
&sockaddrSize); &sockaddrSize);
socketWriteMap[connectionFileDescriptor] = connectionOperator.HandleNewConnection(connectionFileDescriptor); if (clientFd < 0)
epoll_event event;
event.data.fd = connectionFileDescriptor;
event.events = EPOLLOUT | EPOLLONESHOT;
if (epoll_ctl(pollFileDescriptor, EPOLL_CTL_ADD, connectionFileDescriptor, &event) < 0)
{ {
logger.Error("Error registering file descriptor for EPOLLOUT"); logger.Info("Connection dropped on accepting it");
}
}
void HttpServer::HandleEpollOutEvent(int fd)
{
auto result = socketWriteMap.find(fd);
if (result == socketWriteMap.end())
{
logger.Error("EPOLLOUT Event for unexpected fd");
return; return;
} }
Socket::WriteBytes(fd, result->second); scheduleHelper.ScheduleReadFileDescriptor(clientFd);
close(fd); fileDescriptorEventTypeMap[clientFd] = EventType::ReadRequest;
socketWriteMap.erase(result); }
}
void HttpServer::Execute() void HttpServer::ReadRequest(int const socketFd)
{ {
std::vector<epoll_event> epollEvents; auto bytes = Socket::ReadBytes(socketFd, 512);
epollEvents.resize(1000); if (bytes.size() < 1)
{
logger.Info("Could not read anything from socket");
return;
}
auto & request = socketRequestMap[socketFd];
request.Deserialize(bytes, logger);
auto & response = socketResponseMap[socketFd];
auto middlewareResult = middleware.Start(socketFd, scheduleHelper, request, response);
HandleMiddlewareResponse(middlewareResult);
}
void HttpServer::HandleMiddlewareEvent(int const eventFd)
{
auto middlewareResult = middleware.Continue(eventFd, scheduleHelper);
HandleMiddlewareResponse(middlewareResult);
}
void HttpServer::HandleMiddlewareResponse(Middleware::MiddlewareResult const & middlewareResult)
{
if (middlewareResult.writeResult)
{
fileDescriptorEventTypeMap[middlewareResult.socketFd] = EventType::WriteResponse;
scheduleHelper.ScheduleWriteFileDescriptor(middlewareResult.socketFd, true, true);
}
}
void HttpServer::WriteResponse(int const socketFd)
{
auto requestResult = socketRequestMap.find(socketFd);
if (requestResult != socketRequestMap.end())
{
socketRequestMap.erase(requestResult);
}
auto responseResult = socketResponseMap.find(socketFd);
Socket::WriteBytes(socketFd, responseResult->second.Serialize());
if (responseResult != socketResponseMap.end())
{
socketResponseMap.erase(responseResult);
}
}
void HttpServer::Execute()
{
while(isOpen) while(isOpen)
{ {
int eventsHappened = epoll_wait(pollFileDescriptor, epollEvents.data(), static_cast<int>(epollEvents.size()), -1); auto events = scheduler.WaitForEvents();
for (int i = 0; i < eventsHappened; ++i) for (size_t i = 0; i < events.size(); ++i)
{ {
int const fd = epollEvents[i].data.fd; int const fd = events[i];
if (epollEvents[i].events & EPOLLIN)
{
HandleEpollInEvent(fd);
}
else
{
HandleEpollOutEvent(fd);
}
}
}
}
HttpServer::HttpServer(Logger & _logger, ServerConfiguration const & serverConfiguration) auto eventTypeMapIter = fileDescriptorEventTypeMap.find(fd);
if (eventTypeMapIter == fileDescriptorEventTypeMap.end())
{
HandleMiddlewareEvent(fd);
continue;
}
if (eventTypeMapIter->second != EventType::NewConnection)
{
fileDescriptorEventTypeMap.erase(eventTypeMapIter);
}
switch(eventTypeMapIter->second)
{
case EventType::NewConnection:
RegisterNewConnection(fd);
break;
case EventType::ReadRequest:
ReadRequest(fd);
break;
case EventType::WriteResponse:
WriteResponse(fd);
close(fd);
break;
case EventType::Unknown:
default:
close(fd);
logger.Error("FD with unknown EventType fired and was closed");
break;
}
}
}
}
HttpServer::HttpServer(Logger & _logger, Configuration const & serverConfiguration)
: logger(_logger), : logger(_logger),
pollFileDescriptor(-1), scheduler(_logger),
scheduleHelper(scheduler),
listeningSocketFileDescriptor(-1), listeningSocketFileDescriptor(-1),
connectionOperator(_logger, serverConfiguration), middleware(Middleware::Middleware::CreateWithStaticFile(_logger, serverConfiguration.GetWwwRoot())),
isOpen(true) isOpen(true)
{
pollFileDescriptor = epoll_create1(0);
if (pollFileDescriptor < 0)
{ {
throw std::runtime_error("Error creating epoll file descriptor");
}
socketAddress.sin_family = AF_INET; socketAddress.sin_family = AF_INET;
socketAddress.sin_addr.s_addr = INADDR_ANY; socketAddress.sin_addr.s_addr = INADDR_ANY;
socketAddress.sin_port = htons(serverConfiguration.GetPort()); socketAddress.sin_port = htons(serverConfiguration.GetPort());
listeningSocketFileDescriptor = ListeningSocket::Create(socketAddress, 10000); listeningSocketFileDescriptor = ListeningSocket::Create(socketAddress, 1000);
if (listeningSocketFileDescriptor < 0) if (listeningSocketFileDescriptor < 0)
{ {
throw std::runtime_error("Error creating listening socket"); throw std::runtime_error("Error creating listening socket");
} }
epoll_event event; scheduleHelper.ScheduleReadFileDescriptor(listeningSocketFileDescriptor, false);
event.data.fd = listeningSocketFileDescriptor; fileDescriptorEventTypeMap[listeningSocketFileDescriptor] = EventType::NewConnection;
event.events = EPOLLIN;
if (epoll_ctl(pollFileDescriptor, EPOLL_CTL_ADD, listeningSocketFileDescriptor, &event) < 0)
{
throw std::runtime_error("Error registering listening socket with epoll facilities");
}
std::stringstream ss; std::stringstream ss;
ss << "Listening on port " << serverConfiguration.GetPort(); ss << "Listening on port " << serverConfiguration.GetPort();
logger.Info(ss.str()); logger.Info(ss.str());
} }
HttpServer::~HttpServer() HttpServer::~HttpServer()
{ {
if (listeningSocketFileDescriptor >= 0) if (listeningSocketFileDescriptor >= 0)
{ {
close(listeningSocketFileDescriptor); close(listeningSocketFileDescriptor);
} }
if (pollFileDescriptor >= 0) for (auto & item : fileDescriptorEventTypeMap)
{ {
close(pollFileDescriptor); close(item.first);
}
} }
} }

View File

@@ -1,28 +1,47 @@
#pragma once #pragma once
#include "../http/request.hpp"
#include "../http/response.hpp"
#include "../logger.hpp" #include "../logger.hpp"
#include "connectionoperator.hpp" #include "../middleware/middleware.hpp"
#include "../scheduler.hpp"
#include "eventtype.hpp"
#include <map> #include <map>
#include <netinet/in.h>
class HttpServer namespace Server
{ {
private: class HttpServer
{
private:
Logger & logger; Logger & logger;
int pollFileDescriptor; Scheduler scheduler;
ScheduleHelper scheduleHelper;
sockaddr_in socketAddress; sockaddr_in socketAddress;
int listeningSocketFileDescriptor; int listeningSocketFileDescriptor;
std::map<int, std::vector<char>> socketWriteMap;
ConnectionOperator connectionOperator; std::map<int, EventType> fileDescriptorEventTypeMap;
std::map<int, Http::Request> socketRequestMap;
std::map<int, Http::Response> socketResponseMap;
Middleware::Middleware middleware;
bool isOpen; bool isOpen;
void HandleEpollInEvent(int fd); void RegisterNewConnection(int const listeningSocketFd);
void HandleEpollOutEvent(int fd); void ReadRequest(int const socketFd);
public: void HandleMiddlewareEvent(int const eventFd);
void HandleMiddlewareResponse(Middleware::MiddlewareResult const & middlewareResult);
void WriteResponse(int const socketFd);
public:
void Execute(); void Execute();
HttpServer(Logger & logger, ServerConfiguration const & serverConfiguration); HttpServer(Logger & logger, Configuration const & serverConfiguration);
~HttpServer(); ~HttpServer();
HttpServer(HttpServer & other) = delete; };
HttpServer & operator=(HttpServer & other) = delete; }
};

View File

@@ -2,7 +2,7 @@
#include <stdexcept> #include <stdexcept>
#include <unistd.h> #include <unistd.h>
namespace ListeningSocket namespace Server::ListeningSocket
{ {
int Create(sockaddr_in & socketAddress, int const connectionLimit) int Create(sockaddr_in & socketAddress, int const connectionLimit)
{ {
@@ -40,7 +40,7 @@ namespace ListeningSocket
} }
} }
namespace Socket namespace Server::Socket
{ {
std::vector<char> ReadBytes(int fd, size_t limit) std::vector<char> ReadBytes(int fd, size_t limit)
{ {

View File

@@ -3,12 +3,12 @@
#include <sys/socket.h> #include <sys/socket.h>
#include <vector> #include <vector>
namespace ListeningSocket namespace Server::ListeningSocket
{ {
int Create(sockaddr_in & socketAddress, int const connectionLimit); int Create(sockaddr_in & socketAddress, int const connectionLimit);
} }
namespace Socket namespace Server::Socket
{ {
std::vector<char> ReadBytes(int fd, size_t limit); std::vector<char> ReadBytes(int fd, size_t limit);