Files
wassembly/src/compile/compiler.cpp

433 lines
12 KiB
C++

#include <compile/compiler.hpp>
#include <compile/errors.hpp>
#include <execute/bytecode.hpp>
#include <cstdint>
#include <cstdio>
#include <stdexcept>
#include <string>
#include <utility>
#include <variant>
#include <vector>
#include <utils.hpp>
namespace Compile
{
int GetRequiredNumberOfArguments(Token::OperandType const type)
{
switch (type)
{
case Token::OperandType::AddInteger:
case Token::OperandType::SubtractInteger:
case Token::OperandType::DivideInteger:
case Token::OperandType::MultiplyInteger:
case Token::OperandType::ShiftIntegerLeft:
case Token::OperandType::ShiftIntegerRight:
return 3;
case Token::OperandType::LessThanInteger:
case Token::OperandType::GreaterThanInteger:
case Token::OperandType::EqualInteger:
case Token::OperandType::SetInteger:
return 2;
case Token::OperandType::Jump:
case Token::OperandType::CallFunction:
case Token::OperandType::Interrupt:
case Token::OperandType::PushInteger:
case Token::OperandType::PopInteger:
return 1;
default:
std::printf("WARNING: returning default argument length of 0 for operand type %i\n", static_cast<int>(type));
case Token::OperandType::ReturnFromFunction:
case Token::OperandType::ExitProgram:
return 0;
}
}
bool IsArgumentToken(Token::Token const & t)
{
return
t.type == Token::TokenType::ImmediateInteger ||
t.type == Token::TokenType::Register ||
t.type == Token::TokenType::LabelArgument ||
t.type == Token::TokenType::Memory;
}
bool IsReadableToken(Token::Token const & t)
{
return
t.type == Token::TokenType::ImmediateInteger ||
t.type == Token::TokenType::Register ||
t.type == Token::TokenType::Memory;
}
bool IsWriteableToken(Token::Token const & t)
{
return
t.type == Token::TokenType::Register ||
t.type == Token::TokenType::Memory;
}
void ValidateArguments(
std::vector<Token::Token> const & tokens,
std::size_t const operandIndex)
{
auto const operandType = std::get<Token::OperandType>(tokens[operandIndex].data);
switch(operandType)
{
// 2 Read values + 1 write value
case Token::OperandType::AddInteger:
case Token::OperandType::SubtractInteger:
case Token::OperandType::DivideInteger:
case Token::OperandType::MultiplyInteger:
case Token::OperandType::ShiftIntegerLeft:
case Token::OperandType::ShiftIntegerRight:
if (!IsReadableToken(tokens[operandIndex + 1]))
{
throw CompilationError::CreateExpectedImmediateOrRegisterOrMemory(tokens[operandIndex + 1]);
}
if (!IsReadableToken(tokens[operandIndex + 2]))
{
throw CompilationError::CreateExpectedImmediateOrRegisterOrMemory(tokens[operandIndex + 2]);
}
if (!IsWriteableToken(tokens[operandIndex + 3]))
{
throw CompilationError::CreateExpectedRegisterOrMemoryError(tokens[operandIndex + 3]);
}
break;
// 2 Read values
case Token::OperandType::LessThanInteger:
case Token::OperandType::GreaterThanInteger:
case Token::OperandType::EqualInteger:
case Token::OperandType::SetInteger:
if (!IsReadableToken(tokens[operandIndex + 1]))
{
throw CompilationError::CreateExpectedImmediateOrRegisterOrMemory(tokens[operandIndex + 1]);
}
if (!IsReadableToken(tokens[operandIndex + 2]))
{
throw CompilationError::CreateExpectedImmediateOrRegisterOrMemory(tokens[operandIndex + 2]);
}
break;
// 1 Label value
case Token::OperandType::Jump:
case Token::OperandType::CallFunction:
if (tokens[operandIndex + 1].type != Token::TokenType::LabelArgument)
{
throw CompilationError::CreateExpectedLabelError(tokens[operandIndex + 1]);
}
break;
// 1 Read value
case Token::OperandType::Interrupt:
case Token::OperandType::PushInteger:
if (!IsReadableToken(tokens[operandIndex + 1]))
{
throw CompilationError::CreateExpectedImmediateOrRegisterOrMemory(tokens[operandIndex + 1]);
}
break;
// 1 Write value
case Token::OperandType::PopInteger:
if (!IsWriteableToken(tokens[operandIndex + 1]))
{
throw CompilationError::CreateExpectedRegisterOrMemoryError(tokens[operandIndex + 1]);
}
break;
default:
throw std::runtime_error("Unimplemented operandType case in ValidateArguments");
}
}
Execute::RegisterByte GetByteCodeRegister(Token::RegisterType const v)
{
switch(v)
{
case Token::RegisterType::A:
return Execute::RegisterByte::A;
case Token::RegisterType::B:
return Execute::RegisterByte::B;
case Token::RegisterType::C:
return Execute::RegisterByte::C;
case Token::RegisterType::D:
return Execute::RegisterByte::D;
default:
throw std::runtime_error("Unhandled register type in GetByteCodeRegister");
}
}
void Compiler::InsertAsBytes(
Token::Token const & token,
std::vector<std::uint8_t> & bytes)
{
switch(token.type)
{
case Token::TokenType::ImmediateInteger:
bytes.push_back(static_cast<std::uint8_t>(Execute::InstructionByte::IMMEDIATE_INTEGER));
{
int value = std::get<int>(token.data);
auto const insertionIndex = bytes.size();
bytes.resize(bytes.size() + 4);
Utils::Bytes::Write(value, bytes, insertionIndex);
}
break;
case Token::TokenType::Operand:
{
switch(std::get<Token::OperandType>(token.data))
{
case Token::OperandType::AddInteger:
bytes.push_back(static_cast<std::uint8_t>(Execute::InstructionByte::ADD_INTEGER));
break;
case Token::OperandType::SubtractInteger:
bytes.push_back(static_cast<std::uint8_t>(Execute::InstructionByte::SUBTRACT_INTEGER));
break;
case Token::OperandType::DivideInteger:
bytes.push_back(static_cast<std::uint8_t>(Execute::InstructionByte::DIVIDE_INTEGER));
break;
case Token::OperandType::MultiplyInteger:
bytes.push_back(static_cast<std::uint8_t>(Execute::InstructionByte::MULTIPLY_INTEGER));
break;
case Token::OperandType::ShiftIntegerLeft:
bytes.push_back(static_cast<std::uint8_t>(Execute::InstructionByte::SHIFT_LEFT_INTEGER));
break;
case Token::OperandType::ShiftIntegerRight:
bytes.push_back(static_cast<std::uint8_t>(Execute::InstructionByte::SHIFT_RIGHT_INTEGER));
break;
case Token::OperandType::LessThanInteger:
bytes.push_back(static_cast<std::uint8_t>(Execute::InstructionByte::LESS_THAN_INTEGER));
break;
case Token::OperandType::GreaterThanInteger:
bytes.push_back(static_cast<std::uint8_t>(Execute::InstructionByte::GREATER_THAN_INTEGER));
break;
case Token::OperandType::EqualInteger:
bytes.push_back(static_cast<std::uint8_t>(Execute::InstructionByte::EQUALS_INTEGER));
break;
case Token::OperandType::SetInteger:
bytes.push_back(static_cast<std::uint8_t>(Execute::InstructionByte::SET_INTEGER));
break;
case Token::OperandType::Jump:
bytes.push_back(static_cast<std::uint8_t>(Execute::InstructionByte::JUMP));
break;
case Token::OperandType::CallFunction:
bytes.push_back(static_cast<std::uint8_t>(Execute::InstructionByte::CALL));
break;
case Token::OperandType::Interrupt:
bytes.push_back(static_cast<std::uint8_t>(Execute::InstructionByte::INTERRUPT));
break;
case Token::OperandType::PushInteger:
bytes.push_back(static_cast<std::uint8_t>(Execute::InstructionByte::PUSH_INTEGER));
break;
case Token::OperandType::PopInteger:
bytes.push_back(static_cast<std::uint8_t>(Execute::InstructionByte::POP_INTEGER));
break;
case Token::OperandType::ReturnFromFunction:
bytes.push_back(static_cast<std::uint8_t>(Execute::InstructionByte::RETURN));
break;
case Token::OperandType::ExitProgram:
bytes.push_back(static_cast<std::uint8_t>(Execute::InstructionByte::EXIT));
break;
break;
default:
throw std::runtime_error("Unhandled operand type in InsertAsBytes");
}
}
break;
case Token::TokenType::Register:
bytes.push_back(static_cast<std::uint8_t>(Execute::InstructionByte::REGISTER));
bytes.push_back(static_cast<std::uint8_t>(GetByteCodeRegister(std::get<Token::RegisterType>(token.data))));
break;
case Token::TokenType::StatementEnd:
case Token::TokenType::LabelDefinition:
// NO OP
break;
case Token::TokenType::LabelArgument:
{
bytes.push_back(static_cast<std::uint8_t>(Execute::InstructionByte::LABEL));
auto const & label = std::get<std::string>(token.data);
auto const findResult = jumpLabelLocations.find(label);
int jumpLocation = 0;
if (findResult == jumpLabelLocations.end())
{
unresolvedJumpLabels.push_back(std::make_pair(token, bytes.size()));
}
else
{
jumpLocation = findResult->second;
}
auto const insertionIndex = bytes.size();
bytes.resize(bytes.size() + 4);
Utils::Bytes::Write(jumpLocation, bytes, insertionIndex);
}
break;
case Token::TokenType::Memory:
{
bytes.push_back(static_cast<std::uint8_t>(Execute::InstructionByte::MEMORY_OP));
switch(token.valueType)
{
case Token::TokenValueType::Register:
bytes.push_back(static_cast<std::uint8_t>(Execute::InstructionByte::REGISTER));
bytes.push_back(static_cast<std::uint8_t>(GetByteCodeRegister(std::get<Token::RegisterType>(token.data))));
break;
case Token::TokenValueType::Integer:
bytes.push_back(static_cast<std::uint8_t>(Execute::InstructionByte::IMMEDIATE_INTEGER));
{
auto const insertionIndex = bytes.size();
bytes.resize(bytes.size() + 4);
Utils::Bytes::Write(std::get<int>(token.data), bytes, insertionIndex);
}
break;
default:
throw std::runtime_error("Unhandled value type for memory operand in InsertAsBytes");
}
}
break;
default:
throw std::runtime_error("Unhandled token type in InsertAsBytes");
}
}
bool Compiler::Compile(
std::vector<Token::Token> const & tokens,
std::vector<std::uint8_t> & bytes)
{
jumpLabelLocations.clear();
unresolvedJumpLabels.clear();
enum class State
{
FindOperand,
FindArguments,
FindStatementEnd
};
State state = State::FindOperand;
Token::OperandType operandType;
unsigned operatorTokenIndex = 0u;
int expectedNumberOfArguments = 0;
for(std::size_t i = 0u; i < tokens.size(); ++i)
{
auto const & token = tokens[i];
InsertAsBytes(token, bytes);
switch(state)
{
case State::FindOperand:
switch(token.type)
{
case Token::TokenType::Operand:
operatorTokenIndex = i;
operandType = std::get<Token::OperandType>(token.data);
expectedNumberOfArguments = GetRequiredNumberOfArguments(operandType);
if (expectedNumberOfArguments < 1)
{
state = State::FindStatementEnd;
}
else
{
state = State::FindArguments;
}
break;
case Token::TokenType::LabelDefinition:
{
auto findResult = jumpLabelLocations.find(std::get<std::string>(token.data));
if (findResult == jumpLabelLocations.end())
{
jumpLabelLocations[std::get<std::string>(token.data)] = bytes.size();
}
else
{
throw CompilationError::CreateDuplicateLabelError(token);
}
}
break;
case Token::TokenType::StatementEnd:
// NO OP
break;
default:
throw CompilationError::CreateExpectedOperandError(token);
}
break;
case State::FindArguments:
if (IsArgumentToken(token))
{
expectedNumberOfArguments -= 1;
if (expectedNumberOfArguments < 1)
{
ValidateArguments(tokens, operatorTokenIndex);
state = State::FindStatementEnd;
}
}
else
{
// TODO Further specify this error?
throw CompilationError::CreateExpectedArgumentError(token);
}
break;
case State::FindStatementEnd:
if (token.type != Token::TokenType::StatementEnd)
{
// TODO Further specify this error?
throw CompilationError::CreateExpectedEndOfStatementError(token);
}
else
{
InsertAsBytes(
token,
bytes);
state = State::FindOperand;
}
break;
}
}
for(auto const & unresolved : unresolvedJumpLabels)
{
auto const & findResult = jumpLabelLocations.find(std::get<std::string>(unresolved.first.data));
if (findResult == jumpLabelLocations.end())
{
throw CompilationError::CreateNonExistingLabelError(unresolved.first);
}
int const jumpLocation = findResult->second;
auto const index = unresolved.second;
Utils::Bytes::Write(jumpLocation, bytes, index);
}
return true;
}
}