src/lib/BasicASN1Reader.h
author František Kučera <franta-hg@frantovo.cz>
Sun, 04 Jul 2021 11:51:13 +0200
branch v_0
changeset 27 d9cc2d356cdb
parent 26 e39de9b8b3a1
child 28 fade2f562970
permissions -rw-r--r--
add common Header argument to ASN1ContentHandler methods

/**
 * Relational pipes
 * Copyright © 2021 František Kučera (Frantovo.cz, GlobalCode.info)
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, version 3 of the License.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program. If not, see <http://www.gnu.org/licenses/>.
 */
#pragma once

#include <memory>
#include <vector>
#include <array>
#include <sstream>
#include <regex>

#include "ASN1Reader.h"

namespace relpipe {
namespace in {
namespace asn1 {
namespace lib {

/**
 * Reads ASN.1 data encoded as BER (DER, CER).
 */
class BasicASN1Reader : public ASN1Reader {
private:

	bool started = false;

	class BasicHeader : public ASN1ContentHandler::Header {
	public:
		bool definiteLength;
		size_t length;
	};

	class LevelMetadata {
	public:
		bool definiteLength;
		size_t length;
		size_t start;
	};

	std::vector<LevelMetadata> level;

	void checkRemainingItems() {
		if (level.size()) {
			LevelMetadata& l = level.back();
			if (l.definiteLength && l.length == getBytesRead() - l.start) {
				level.pop_back();
				handlers.writeCollectionEnd();
				checkRemainingItems(); // multiple collections may end at the same point
			}
		}
	}

	BasicHeader readHeader() {
		using TagClass = ASN1ContentHandler::TagClass;
		using PC = ASN1ContentHandler::PC;

		BasicHeader h;

		memset(&h, 0, sizeof (h)); // TODO: remove, not needed

		uint8_t tagByte;
		read(&tagByte, 1);

		h.tagClass = (TagClass) (tagByte >> 6);
		h.pc = (PC) ((tagByte >> 5) & 1);
		h.tag = tagByte & (0xFF >> 3);
		if (h.tag == 31) { // all five tag bits are set → tag number (greater than 30) is encoded in following octets
			h.tag = 0;
			uint8_t moreTag = 0;
			do {
				read(&moreTag, 1);
				h.tag = h.tag << 7 | (moreTag & (0xFF >> 1));
			} while (moreTag & (1 << 7));
		}

		uint8_t lengthByte;
		read(&lengthByte, 1);

		if (lengthByte >> 7 == 0) {
			// definite short
			h.definiteLength = true;
			h.length = lengthByte;
		} else if (lengthByte == 0b10000000) {
			// indefinite
			h.definiteLength = false;
			h.length = 0;
		} else if (lengthByte == 0xFF) {
			throw relpipe::writer::RelpipeWriterException(L"ASN.1 lengthByte == 0xFF (reserved value)"); // TODO: better exception
		} else {
			// definite long
			h.definiteLength = true;
			h.length = 0;
			std::vector<uint8_t> lengthBytes(lengthByte & 0b01111111, 0);
			read(lengthBytes.data(), lengthBytes.size());
			for (uint8_t l : lengthBytes) h.length = (h.length << 8) + l;
		}

		return h;
	}

	void readNext() {
		using TagClass = ASN1ContentHandler::TagClass;
		using PC = ASN1ContentHandler::PC;

		checkRemainingItems();
		BasicHeader typeHeader = readHeader();
		// commit(); // TODO: commit here and recover later instead of rollback?

		if (!started) {
			handlers.writeStreamStart();
			started = true;
		}

		// TODO: check tagClass and pc

		// TODO: constants, more types
		if (typeHeader.tag == UniversalType::EndOfContent && typeHeader.tagClass == TagClass::Universal && typeHeader.pc == PC::Primitive) {
			handlers.writeCollectionEnd();
		} else if (typeHeader.tag == UniversalType::Sequence) {
			level.push_back({typeHeader.definiteLength, typeHeader.length, getBytesRead()}); // TODO: transaction
			handlers.writeCollectionStart(typeHeader);
		} else if (typeHeader.tag == UniversalType::Set) {
			level.push_back({typeHeader.definiteLength, typeHeader.length, getBytesRead()}); // TODO: transaction
			handlers.writeCollectionStart(typeHeader);
		} else if (typeHeader.pc == PC::Constructed) {
			level.push_back({typeHeader.definiteLength, typeHeader.length, getBytesRead()}); // TODO: transaction
			handlers.writeCollectionStart(typeHeader);
		} else if (typeHeader.tag == UniversalType::Null && typeHeader.length == 0) {
			handlers.writeNull(typeHeader);
		} else if (typeHeader.tag == UniversalType::Boolean && typeHeader.definiteLength && typeHeader.length == 1) {
			bool value;
			read((uint8_t*) & value, 1);
			handlers.writeBoolean(typeHeader, value);
		} else if (typeHeader.tag == UniversalType::Integer && typeHeader.tagClass == TagClass::Universal && typeHeader.definiteLength) {
			// TODO: check available bytes before allocating buffer
			std::vector<uint8_t> value(typeHeader.length, 0x00);
			read(value.data(), typeHeader.length);
			handlers.writeInteger(typeHeader, ASN1ContentHandler::Integer(value));
		} else if (typeHeader.tag == UniversalType::ObjectIdentifier && typeHeader.tagClass == TagClass::Universal && typeHeader.definiteLength) {
			std::vector<uint8_t> value(typeHeader.length, 0x00);
			read(value.data(), typeHeader.length);
			handlers.writeOID(typeHeader,{value});
		} else if (typeHeader.tag == UniversalType::UTF8String && typeHeader.tagClass == TagClass::Universal && typeHeader.definiteLength) {
			// TODO: check available bytes before allocating buffer
			std::string s;
			s.resize(typeHeader.length);
			read((uint8_t*) s.data(), typeHeader.length);
			handlers.writeTextString(typeHeader, s);
		} else if (typeHeader.tag == UniversalType::PrintableString && typeHeader.tagClass == TagClass::Universal && typeHeader.definiteLength) {
			// TODO: check encoding
			// TODO: check available bytes before allocating buffer
			std::string s;
			s.resize(typeHeader.length);
			read((uint8_t*) s.data(), typeHeader.length);
			handlers.writeTextString(typeHeader, s);
		} else if (typeHeader.tag == UniversalType::OctetString && typeHeader.tagClass == TagClass::Universal && typeHeader.definiteLength) {
			// TODO: check available bytes before allocating buffer
			std::string s;
			s.resize(typeHeader.length);
			read((uint8_t*) s.data(), typeHeader.length);
			handlers.writeOctetString(typeHeader, s);
		} else if (typeHeader.tag == UniversalType::BitString && typeHeader.tagClass == TagClass::Universal && typeHeader.definiteLength) {
			// TODO: check available bytes before allocating buffer
			std::string s;
			s.resize(typeHeader.length);
			read((uint8_t*) s.data(), typeHeader.length);
			std::vector<bool> bits;
			// TODO: throw exception on wrong padding or insufficient length?
			if (s.size() > 1) {
				uint8_t padding = s[0];
				for (uint8_t j = padding; j < 8; j++) bits.push_back(s.back() & 1 << j);
				for (size_t i = s.size() - 2; i > 0; i--) for (uint8_t j = 0; j < 8; j++) bits.push_back(s[i] & 1 << j);
			}
			handlers.writeBitString(typeHeader, bits);
		} else if (typeHeader.tag == UniversalType::UTCTime && typeHeader.tagClass == TagClass::Universal && typeHeader.definiteLength) {
			// TODO: check available bytes before allocating buffer
			// TODO: check encoding
			std::string s;
			s.resize(typeHeader.length);
			read((uint8_t*) s.data(), typeHeader.length);

			ASN1ContentHandler::DateTime dateTime;

			std::smatch match;
			if (std::regex_match(s, match, std::regex("([0-9]{2})([0-9]{2})([0-9]{2})([0-9]{2})([0-9]{2})([0-9]{2})?(Z|([+-][0-9]{2})'?([0-9]{2})'?)"))) {
				// Supported UTCTime formats:
				// YYMMDDhhmmZ
				// YYMMDDhhmmssZ
				// YYMMDDhhmm+hhmm
				// YYMMDDhhmm-hhmm
				// YYMMDDhhmmss+hhmm
				// YYMMDDhhmmss-hhmm
				int i = 1;
				uint32_t year = std::stoi(match[i++]);
				dateTime.year = year < 50 ? 2000 + year : 1900 + year;
				dateTime.month = std::stoi(match[i++]);
				dateTime.day = std::stoi(match[i++]);
				dateTime.hour = std::stoi(match[i++]);
				dateTime.minute = std::stoi(match[i++]);
				dateTime.precision = match[i].length() ? ASN1ContentHandler::DateTime::Precision::Second : ASN1ContentHandler::DateTime::Precision::Minute;
				dateTime.second = match[i].length() ? std::stoi(match[i++]) : 0;
				if (match[i++] != "Z") {
					dateTime.timezoneHour = std::stoi(match[i++]);
					dateTime.timezoneMinute = std::stoi(match[i++]);
				}
				handlers.writeDateTime(typeHeader, dateTime);
			} else {
				throw std::invalid_argument("Unsupported UTCTime format: " + s); // TODO: better exception
			}

		} else if (typeHeader.tag == UniversalType::GeneralizedTime && typeHeader.tagClass == TagClass::Universal && typeHeader.definiteLength) {
			// TODO: check available bytes before allocating buffer
			std::string s;
			s.resize(typeHeader.length);
			read((uint8_t*) s.data(), typeHeader.length);

			ASN1ContentHandler::DateTime dateTime;

			std::smatch match;
			if (std::regex_match(s, match, std::regex("([0-9]{4})([0-9]{2})([0-9]{2})([0-9]{2})([0-9]{2})([0-9]{2})(\\.([0-9]{1,3}))?(Z|([+-][0-9]{2})'?([0-9]{2})'?)"))) {
				int i = 1;
				dateTime.year = std::stoi(match[i++]);
				dateTime.month = std::stoi(match[i++]);
				dateTime.day = std::stoi(match[i++]);
				dateTime.hour = std::stoi(match[i++]);
				dateTime.minute = std::stoi(match[i++]);
				dateTime.second = match[i].length() ? std::stoi(match[i++]) : 0;
				dateTime.precision = match[i++].length() ? ASN1ContentHandler::DateTime::Precision::Nanosecond : ASN1ContentHandler::DateTime::Precision::Second;
				if (match[i].length() == 1) dateTime.nanosecond = std::stoi(match[i++]) * 100 * 1000000;
				else if (match[i].length() == 2) dateTime.nanosecond = std::stoi(match[i++]) * 10 * 1000000;
				else if (match[i].length() == 3) dateTime.nanosecond = std::stoi(match[i++]) * 1000000;
				else i++;
				if (match[i++] != "Z") {
					dateTime.timezoneHour = std::stoi(match[i++]);
					dateTime.timezoneMinute = std::stoi(match[i++]);
				}
				handlers.writeDateTime(typeHeader, dateTime);
			} else {
				throw std::invalid_argument("Unsupported GeneralizedTime format: " + s); // TODO: better exception
			}

		} else {
			// TODO: do not skip, parse
			std::vector<uint8_t> temp(typeHeader.length, 0);
			read(temp.data(), typeHeader.length);
			// TODO: recover transaction?

			std::stringstream description;
			description << "value:"
					<< " tag = " << typeHeader.tag
					<< " tagClass = " << (int) typeHeader.tagClass
					<< " pc = " << (int) typeHeader.pc
					<< " length = " << typeHeader.length
					<< " definite = " << (typeHeader.definiteLength ? "true" : "false");

			handlers.writeTextString(typeHeader, description.str());
		}

		commit();
	}

protected:

	void update() override {
		while (true) readNext();
	}

public:

	void close() override {
		checkRemainingItems();
		// TODO: check the bytes remaining in the buffer
		if (started) handlers.writeStreamEnd();
	}

};

}
}
}
}