src/lib/BasicASN1Reader.h
author František Kučera <franta-hg@frantovo.cz>
Sat, 04 Dec 2021 21:14:48 +0100
branch v_0
changeset 11 6282949e3672
parent 1 68a281aefa76
permissions -rw-r--r--
Added tag v0.18 for changeset db8429c641c6

/**
 * Relational pipes
 * Copyright © 2021 František Kučera (Frantovo.cz, GlobalCode.info)
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, version 3 of the License.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program. If not, see <http://www.gnu.org/licenses/>.
 */
#pragma once

#include <memory>
#include <vector>
#include <array>
#include <sstream>
#include <regex>

#include "ASN1Reader.h"
#include "ValidatingASN1ContentHandler.h"
#include "uri.h"

namespace relpipe {
namespace in {
namespace asn1 {
namespace lib {

/**
 * Reads ASN.1 data encoded as BER (DER, CER).
 */
class BasicASN1Reader : public ASN1Reader {
private:

	bool started = false;

	bool parseEncapsulated = true;

	/**
	 * Converts the textual form of a boolean option to bool.
	 * TODO: use a common method
	 *
	 * @param value text to be parsed; only the exact values "true" and "false" are accepted
	 * @return true for "true", false for "false"
	 * @throws std::invalid_argument for any other value
	 */
	bool parseBoolean(const std::string& value) {
		const bool isTrue = value == "true";
		if (!isTrue && value != "false") {
			throw std::invalid_argument(std::string("Unable to parse boolean value: ") + value + " (expecting true or false)");
		}
		return isTrue;
	}

	class BasicHeader : public ASN1ContentHandler::Header {
	public:
		bool definiteLength;
		size_t length;
	};

	/**
	 * Bookkeeping record of one currently open (nested) collection.
	 */
	struct LevelMetadata {
		/** whether the collection was announced with a definite length */
		bool definiteLength;
		/** announced length of the collection content in octets */
		size_t length;
		/** value of getBytesRead() taken when the collection was opened, i.e. after its header was read */
		size_t start;
	};

	std::vector<LevelMetadata> level;

	void checkRemainingItems() {
		if (level.size()) {
			LevelMetadata& l = level.back();
			if (l.definiteLength && l.length == getBytesRead() - l.start) {
				level.pop_back();
				handlers->writeCollectionEnd();
				checkRemainingItems(); // multiple collections may end at the same point
			}
		}
	}

	/**
	 * Reads the identifier octets (tag class, primitive/constructed flag, tag number)
	 * and the length octets of the next element.
	 *
	 * @return header with tagClass, pc, tag, definiteLength and length filled;
	 * for the indefinite form, definiteLength is false and length is 0
	 * @throws relpipe::writer::RelpipeWriterException if the reserved length octet 0xFF is found
	 * or if the tag number or the length is too large to be stored
	 */
	BasicHeader readHeader() {
		using TagClass = ASN1ContentHandler::TagClass;
		using PC = ASN1ContentHandler::PC;

		// value-initialization instead of memset(), which is not safe for class types in general
		BasicHeader h{};

		uint8_t tagByte;
		read(&tagByte, 1);

		h.tagClass = (TagClass) (tagByte >> 6);
		h.pc = (PC) ((tagByte >> 5) & 1);
		h.tag = tagByte & (0xFF >> 3);
		if (h.tag == 31) { // all five tag bits are set → tag number (greater than 30) is encoded in following octets
			h.tag = 0;
			uint8_t moreTag = 0;
			do {
				read(&moreTag, 1);
				// the most significant octet must be still empty, otherwise the shift below would silently lose bits
				if (h.tag >> (8 * (sizeof (h.tag) - 1))) throw relpipe::writer::RelpipeWriterException(L"ASN.1 tag number is too large"); // TODO: better exception
				h.tag = h.tag << 7 | (moreTag & (0xFF >> 1));
			} while (moreTag & (1 << 7)); // the highest bit says that another octet follows
		}

		uint8_t lengthByte;
		read(&lengthByte, 1);

		if (lengthByte >> 7 == 0) {
			// definite short: the octet itself is the length
			h.definiteLength = true;
			h.length = lengthByte;
		} else if (lengthByte == 0b10000000) {
			// indefinite: the content is terminated by the end-of-content element
			h.definiteLength = false;
			h.length = 0;
		} else if (lengthByte == 0xFF) {
			throw relpipe::writer::RelpipeWriterException(L"ASN.1 lengthByte == 0xFF (reserved value)"); // TODO: better exception
		} else {
			// definite long: the lower seven bits say how many length octets (big-endian) follow
			h.definiteLength = true;
			h.length = 0;
			std::vector<uint8_t> lengthBytes(lengthByte & 0b01111111, 0);
			read(lengthBytes.data(), lengthBytes.size());
			for (uint8_t l : lengthBytes) {
				// leading zero octets are harmless, but significant bits must not be shifted out of the variable
				if (h.length >> (8 * (sizeof (h.length) - 1))) throw relpipe::writer::RelpipeWriterException(L"ASN.1 length is too large"); // TODO: better exception
				h.length = (h.length << 8) + l;
			}
		}

		return h;
	}

	/**
	 * Reads the given number of octets from the stream.
	 * The data are fetched in chunks of at most three octets and the result is enlarged chunk by chunk.
	 *
	 * @param length number of octets to be read
	 * @return string holding the raw octets that were read
	 */
	const std::string readString(size_t length) {
		const size_t chunkSize = 3;
		std::string result;

		size_t done = 0;
		while (done < length) {
			const size_t current = std::min(length - done, chunkSize);
			result.resize(done + current);
			read(reinterpret_cast<uint8_t*> (&result[done]), current);
			done += current;
		}

		return result;
	}

	/**
	 * Reads the given number of octets from the stream (through readString()) and returns them as a vector.
	 *
	 * @param length number of octets to be read
	 * @return the octets in the order in which they were read
	 */
	const std::vector<uint8_t> readVector(size_t length) {
		// TODO: read directly to the vector
		const std::string raw = readString(length);
		return std::vector<uint8_t>(raw.begin(), raw.end());
	}

	void processNext() {
		using TagClass = ASN1ContentHandler::TagClass;
		using PC = ASN1ContentHandler::PC;

		checkRemainingItems();
		BasicHeader typeHeader = readHeader();
		// commit(); // TODO: commit here and recover later instead of rollback?

		if (!started) {
			handlers->writeStreamStart();
			started = true;
		}

		// TODO: check tagClass and pc

		// TODO: constants, more types
		if (typeHeader.tag == UniversalType::EndOfContent && typeHeader.tagClass == TagClass::Universal && typeHeader.pc == PC::Primitive) {
			handlers->writeCollectionEnd();
		} else if (typeHeader.tag == UniversalType::Sequence) {
			level.push_back({typeHeader.definiteLength, typeHeader.length, getBytesRead()}); // TODO: transaction
			handlers->writeCollectionStart(typeHeader);
		} else if (typeHeader.tag == UniversalType::Set) {
			level.push_back({typeHeader.definiteLength, typeHeader.length, getBytesRead()}); // TODO: transaction
			handlers->writeCollectionStart(typeHeader);
		} else if (typeHeader.pc == PC::Constructed) {
			level.push_back({typeHeader.definiteLength, typeHeader.length, getBytesRead()}); // TODO: transaction
			handlers->writeCollectionStart(typeHeader);
		} else if (typeHeader.tag == UniversalType::Null && typeHeader.length == 0) {
			handlers->writeNull(typeHeader);
		} else if (typeHeader.tag == UniversalType::Boolean && typeHeader.definiteLength && typeHeader.length == 1) {
			bool value;
			read((uint8_t*) & value, 1);
			handlers->writeBoolean(typeHeader, value);
		} else if (typeHeader.tag == UniversalType::Integer && typeHeader.tagClass == TagClass::Universal && typeHeader.definiteLength) {
			std::vector<uint8_t> value = readVector(typeHeader.length);
			handlers->writeInteger(typeHeader, ASN1ContentHandler::Integer(value));
		} else if (typeHeader.tag == UniversalType::ObjectIdentifier && typeHeader.tagClass == TagClass::Universal && typeHeader.definiteLength) {
			std::vector<uint8_t> value(typeHeader.length, 0x00);
			read(value.data(), typeHeader.length);
			handlers->writeOID(typeHeader,{value});
		} else if (typeHeader.tag == UniversalType::UTF8String && typeHeader.tagClass == TagClass::Universal && typeHeader.definiteLength) {
			std::string s = readString(typeHeader.length);
			handlers->writeTextString(typeHeader, s);
		} else if (typeHeader.tag == UniversalType::PrintableString && typeHeader.tagClass == TagClass::Universal && typeHeader.definiteLength) {
			// TODO: check encoding
			std::string s = readString(typeHeader.length);
			handlers->writeTextString(typeHeader, s);
		} else if (typeHeader.tag == UniversalType::OctetString && typeHeader.tagClass == TagClass::Universal && typeHeader.definiteLength) {
			std::string s = readString(typeHeader.length);
			if (processEncapsulatedContent(typeHeader, s) == false) handlers->writeOctetString(typeHeader, s);
		} else if (typeHeader.tag == UniversalType::BitString && typeHeader.tagClass == TagClass::Universal && typeHeader.definiteLength) {
			std::string s = readString(typeHeader.length);
			if (processEncapsulatedContent(typeHeader, s) == false) {
				std::vector<bool> bits;
				// TODO: throw exception on wrong padding or insufficient length?
				if (s.size() > 1) {
					uint8_t padding = s[0];
					for (uint8_t j = padding; j < 8; j++) bits.push_back(s.back() & 1 << j);
					for (size_t i = s.size() - 2; i > 0; i--) for (uint8_t j = 0; j < 8; j++) bits.push_back(s[i] & 1 << j);
				}
				handlers->writeBitString(typeHeader, bits);
			}
		} else if (typeHeader.tag == UniversalType::UTCTime && typeHeader.tagClass == TagClass::Universal && typeHeader.definiteLength) {
			// TODO: check encoding
			std::string s = readString(typeHeader.length);

			ASN1ContentHandler::DateTime dateTime;

			std::smatch match;
			if (std::regex_match(s, match, std::regex("([0-9]{2})([0-9]{2})([0-9]{2})([0-9]{2})([0-9]{2})([0-9]{2})?(Z|([+-][0-9]{2})([0-9]{2}))"))) {
				int i = 1;
				uint32_t year = std::stoi(match[i++]);
				dateTime.year = year < 50 ? 2000 + year : 1900 + year;
				dateTime.month = std::stoi(match[i++]);
				dateTime.day = std::stoi(match[i++]);
				dateTime.hour = std::stoi(match[i++]);
				dateTime.minute = std::stoi(match[i++]);
				dateTime.precision = match[i].length() ? ASN1ContentHandler::DateTime::Precision::Second : ASN1ContentHandler::DateTime::Precision::Minute;
				dateTime.second = match[i].length() ? std::stoi(match[i]) : 0;
				i++;
				if (match[i++] != "Z") {
					dateTime.timezoneHour = std::stoi(match[i++]);
					dateTime.timezoneMinute = std::stoi(match[i++]);
				}
				handlers->writeDateTime(typeHeader, dateTime);
			} else {
				throw std::invalid_argument("Unsupported UTCTime format: " + s); // TODO: better exception
			}

		} else if (typeHeader.tag == UniversalType::GeneralizedTime && typeHeader.tagClass == TagClass::Universal && typeHeader.definiteLength) {
			std::string s = readString(typeHeader.length);

			ASN1ContentHandler::DateTime dateTime;

			std::smatch match;
			if (std::regex_match(s, match, std::regex("([0-9]{4})([0-9]{2})([0-9]{2})([0-9]{2})([0-9]{2})([0-9]{2})(\\.([0-9]{1,3}))?(Z|([+-][0-9]{2})([0-9]{2}))"))) {
				// TODO: support also fractions of minutes and hours in GeneralizedTime
				int i = 1;
				dateTime.year = std::stoi(match[i++]);
				dateTime.month = std::stoi(match[i++]);
				dateTime.day = std::stoi(match[i++]);
				dateTime.hour = std::stoi(match[i++]);
				dateTime.minute = std::stoi(match[i++]);
				dateTime.second = match[i].length() ? std::stoi(match[i++]) : 0;
				dateTime.precision = match[i++].length() ? ASN1ContentHandler::DateTime::Precision::Nanosecond : ASN1ContentHandler::DateTime::Precision::Second;
				if (match[i].length() == 1) dateTime.nanosecond = std::stoi(match[i++]) * 100 * 1000000;
				else if (match[i].length() == 2) dateTime.nanosecond = std::stoi(match[i++]) * 10 * 1000000;
				else if (match[i].length() == 3) dateTime.nanosecond = std::stoi(match[i++]) * 1000000;
				else i++;
				if (match[i++] != "Z") {
					dateTime.timezoneHour = std::stoi(match[i++]);
					dateTime.timezoneMinute = std::stoi(match[i++]);
				}
				handlers->writeDateTime(typeHeader, dateTime);
			} else {
				throw std::invalid_argument("Unsupported GeneralizedTime format: " + s); // TODO: better exception
			}

		} else {
			// TODO: do not skip, parse
			std::string s = readString(typeHeader.length);
			handlers->writeSpecific(typeHeader, s);
		}

		commit();
	}

	/**
	 * Tests whether at least one more octet can be peeked from the input.
	 *
	 * @return true if peek() of a single octet succeeded, false if it threw any exception
	 */
	bool hasAvailableForReading() {
		// TODO: API in AbstractParser for checking available bytes?
		try {
			uint8_t probe;
			peek(&probe, 1);
		} catch (...) {
			return false;
		}
		return true;
	}

	bool isValidBER(const std::string& input) {
		BasicASN1Reader encapsulatedReader;
		std::shared_ptr<ValidatingASN1ContentHandler> validatingHandler = std::make_shared<ValidatingASN1ContentHandler>();
		encapsulatedReader.addHandler(validatingHandler);
		try {
			encapsulatedReader.write((const uint8_t*) input.c_str(), input.size());
			encapsulatedReader.close();
			validatingHandler->finalCheck();
			return true;
		} catch (...) {
			return false;
		}
	}

	/**
	 * Proxy handler used for nested (encapsulated) data:
	 * the stream start and stream end events are dropped here,
	 * all other events are left to ASN1ContentHandlerProxy.
	 */
	struct EncapsulatedASN1ContentHandler : public ASN1ContentHandlerProxy {

		/** Intentionally empty: this event is skipped. */
		void writeStreamStart() override {
		}

		/** Intentionally empty: this event is skipped. */
		void writeStreamEnd() override {
		}
	};

	/**
	 * @param typeHeader
	 * @param input OCTET STRING or BIT STRING raw bytes
	 * @return whether we found valid content and passed parsed results to handlers
	 */
	bool processEncapsulatedContent(const BasicHeader& typeHeader, const std::string& input) {
		// TODO: avoid double parsing + encapsulated content might be also processed at the XML/DOM level where we may even do conditional processing based on XPath (evaluate only certain octet- or bit- strings)
		// We may also do the same as with SEQUENCE or SET (continue nested reading in this ASN1Rreader instance), but it would require valid encapsulated data and would avoid easy fallback to raw OCTET or BIT STRING. We would also have to check the boundaries of the nested part.
		if (parseEncapsulated && isValidBER(input)) {
			handlers->writeCollectionStart(typeHeader);

			BasicASN1Reader encapsulatedReader;
			std::shared_ptr<EncapsulatedASN1ContentHandler> encapsulatedHandler = std::make_shared<EncapsulatedASN1ContentHandler>();
			encapsulatedHandler->addHandler(handlers);
			encapsulatedReader.addHandler(encapsulatedHandler);

			encapsulatedReader.write((const uint8_t*) input.c_str(), input.size());
			encapsulatedReader.close();

			handlers->writeCollectionEnd();
			return true;
		} else {
			return false;
		}
	}

protected:

	/**
	 * Processes the elements one by one.
	 * The loop has no exit condition – it is left only by an exception thrown from processNext().
	 */
	void update() override {
		for (;;) processNext();
	}

public:

	/**
	 * Sets a parser option.
	 *
	 * @param uri identifier of the option
	 * @param value textual value of the option
	 * @return true if the option was recognized and accepted, false if the URI is unknown to this reader
	 * @throws std::invalid_argument if the encoding is not supported or the boolean value cannot be parsed
	 */
	bool setOption(const std::string& uri, const std::string& value) override {
		if (uri == option::Encoding) {
			// currently, we support only BER (and thus also CER and DER) encoding, but options have no actual effect – we just validate them
			// in future versions, CER or DER might switch the parser into more strict mode
			if (value == encoding::ber || value == encoding::cer || value == encoding::der) return true;
			if (value == encoding::per) throw std::invalid_argument("PER encoding is not yet supported");
			if (value == encoding::xer) throw std::invalid_argument("XER encoding is not yet supported");
			if (value == encoding::asn1) throw std::invalid_argument("ASN.1 encoding is not yet supported");
			throw std::invalid_argument("Unsupported ASN.1 encoding: " + value);
		}

		if (uri == option::ParseEncapsulated) {
			parseEncapsulated = parseBoolean(value);
			return true;
		}

		return false;
	}

	/**
	 * Finishes the parsing: closes the definite-length collections that are complete
	 * and sends writeStreamEnd() to the handlers if the stream was started.
	 *
	 * @throws std::logic_error if there are still some octets available for reading
	 */
	void close() override {
		const bool leftover = hasAvailableForReading();
		if (leftover) {
			throw std::logic_error("Unexpected content at the end of the stream"); // TODO: better exception
		}

		// TODO: check also open sequences etc.; maybe in the handler

		checkRemainingItems();
		// TODO: check the bytes remaining in the buffer
		if (!started) return;
		handlers->writeStreamEnd();
	}

};

}
}
}
}