diff --git a/src/SecurityManager.cpp b/src/SecurityManager.cpp index b293825..3a00d9a 100644 --- a/src/SecurityManager.cpp +++ b/src/SecurityManager.cpp @@ -1,7 +1,15 @@ #include SecurityManager::SecurityManager(AsyncWebServer* server, FS* fs) : SettingsPersistence(fs, SECURITY_SETTINGS_FILE) { + // fetch users server->on(USERS_PATH, HTTP_GET, std::bind(&SecurityManager::fetchUsers, this, std::placeholders::_1)); + + // sign in request + _signInRequestHandler.setUri(SIGN_IN_PATH); + _signInRequestHandler.setMethod(HTTP_POST); + _signInRequestHandler.setMaxContentLength(MAX_SECURITY_MANAGER_SETTINGS_SIZE); + _signInRequestHandler.onRequest(std::bind(&SecurityManager::signIn, this, std::placeholders::_1, std::placeholders::_2)); + server->addHandler(&_signInRequestHandler); } SecurityManager::~SecurityManager() {} @@ -52,6 +60,35 @@ void SecurityManager::writeToJsonObject(JsonObject& root) { } } +// TODO - Decide about default role behaviour, don't over-engineer (multiple roles, boolean admin flag???). +void SecurityManager::signIn(AsyncWebServerRequest *request, JsonDocument &jsonDocument){ + if (jsonDocument.is()) { + // authenticate user + String username = jsonDocument["username"]; + String password = jsonDocument["password"]; + User user = authenticate(username, password); + + if (user.isAuthenticated()) { + // create JWT + DynamicJsonDocument _jsonDocument(MAX_JWT_SIZE); + JsonObject jwt = _jsonDocument.to(); + jwt["user"] = user.getUsername(); + jwt["role"] = user.getRole(); + + // send JWT response + AsyncJsonResponse * response = new AsyncJsonResponse(MAX_USERS_SIZE); + JsonObject jsonObject = response->getRoot(); + jsonObject["access_token"] = jwtHandler.encodeJWT(jwt); + response->setLength(); + request->send(response); + } + } + + // authentication failed + AsyncWebServerResponse *response = request->beginResponse(401); + request->send(response); +} + void SecurityManager::fetchUsers(AsyncWebServerRequest *request) { AsyncJsonResponse * response = new AsyncJsonResponse(MAX_USERS_SIZE); JsonObject jsonObject = response->getRoot(); @@ -61,7 +98,11 @@ void SecurityManager::fetchUsers(AsyncWebServerRequest *request) { } void SecurityManager::begin() { + // read config readFromFS(); + + // configure secret + jwtHandler.setPSK(_jwtSecret); } User SecurityManager::verifyUser(String jwt) { diff --git a/src/SecurityManager.h b/src/SecurityManager.h index 5bda393..a58e24e 100644 --- a/src/SecurityManager.h +++ b/src/SecurityManager.h @@ -2,11 +2,11 @@ #define SecurityManager_h #include -#include #include #include #include +#include #define DEFAULT_JWT_SECRET "esp8266-react" @@ -14,12 +14,15 @@ #define USERS_PATH "/rest/users" #define AUTHENTICATE_PATH "/rest/authenticate" +#define SIGN_IN_PATH "/rest/signIn" +#define MAX_JWT_SIZE 128 +#define MAX_SECURITY_MANAGER_SETTINGS_SIZE 512 #define SECURITY_MANAGER_MAX_USERS 5 -#define UNAUTHENTICATED_USERNAME "_anonymous" -#define UNAUTHENTICATED_PASSWORD "" -#define UNAUTHENTICATED_ROLE "" +#define ANONYMOUS_USERNAME "_anonymous" +#define ANONYMOUS_PASSWORD "" +#define ANONYMOUS_ROLE "" #define MAX_USERS_SIZE 1024 @@ -40,11 +43,11 @@ class User { return _role; } bool isAuthenticated() { - return _username != UNAUTHENTICATED_USERNAME; + return _username != ANONYMOUS_USERNAME; } }; -const User NOT_AUTHENTICATED = User(UNAUTHENTICATED_USERNAME, UNAUTHENTICATED_PASSWORD, UNAUTHENTICATED_ROLE); +const User NOT_AUTHENTICATED = User(ANONYMOUS_USERNAME, ANONYMOUS_PASSWORD, ANONYMOUS_ROLE); class SecurityManager : public SettingsPersistence { @@ -55,8 +58,19 @@ class SecurityManager : public SettingsPersistence { void begin(); + /* + * Lookup the user by JWT + */ User verifyUser(String jwt); + + /* + * Authenticate, returning the user if found. + */ User authenticate(String username, String password); + + /* + * Generate a JWT for the user provided + */ String generateJWT(User user); protected: @@ -65,9 +79,12 @@ class SecurityManager : public SettingsPersistence { void writeToJsonObject(JsonObject& root); private: + // jwt handler + ArduinoJsonJWT jwtHandler = ArduinoJsonJWT(DEFAULT_JWT_SECRET); // server instance AsyncWebServer* _server; + AsyncJsonRequestWebHandler _signInRequestHandler; // access point settings String _jwtSecret; @@ -76,7 +93,7 @@ class SecurityManager : public SettingsPersistence { // endpoint functions void fetchUsers(AsyncWebServerRequest *request); - + void signIn(AsyncWebServerRequest *request, JsonDocument &jsonDocument); }; #endif // end SecurityManager_h \ No newline at end of file diff --git a/src/SettingsService.h b/src/SettingsService.h index 65a771a..8adc4da 100644 --- a/src/SettingsService.h +++ b/src/SettingsService.h @@ -32,7 +32,7 @@ private: request->send(response); } - void updateConfig(AsyncWebServerRequest *request, JsonDocument &jsonDocument){ + void updateConfig(AsyncWebServerRequest *request, JsonDocument &jsonDocument) { if (jsonDocument.is()){ JsonObject newConfig = jsonDocument.as(); readFromJsonObject(newConfig); diff --git a/src/jwt/ArduinoJsonJWT.cpp b/src/jwt/ArduinoJsonJWT.cpp new file mode 100644 index 0000000..ac384b2 --- /dev/null +++ b/src/jwt/ArduinoJsonJWT.cpp @@ -0,0 +1,50 @@ +#include "jwt/ArduinoJsonJWT.h" + +ArduinoJsonJWT::ArduinoJsonJWT(String psk) : _psk(psk) { } + +void ArduinoJsonJWT::setPSK(String psk){ + _psk = psk; +} + +String ArduinoJsonJWT::encodeJWT(JsonObject payload) { + // serialize payload + String serializedPayload; + serializeJson(payload, serializedPayload); + + // calculate length of string + uint16_t encodedPayloadLength = encode_base64_length(serializedPayload.length()); + + // create JWT char array + char encodedJWT[BASE_JWT_LENGTH + encodedPayloadLength]; + unsigned char* ptr = (unsigned char*) encodedJWT; + + // 1 - add the header + memcpy(ptr, JWT_HEADER, JWT_HEADER_LENGTH); + ptr += JWT_HEADER_LENGTH; + + // 2 - add payload, trim and null terminate + *ptr++ = '.'; + encode_base64((unsigned char*) serializedPayload.c_str(), serializedPayload.length(), ptr); + ptr += encodedPayloadLength; + while(*(ptr - 1) == '=') { + ptr--; + } + *(ptr) = 0; + + // ... calculate ... + Sha256.initHmac((const unsigned char*)_psk.c_str(), _psk.length()); + Sha256.print(encodedJWT); + + // 3 - add signature + *ptr++ = '.'; + encode_base64(Sha256.resultHmac(), 32, ptr); + ptr += SIGNATURE_LENGTH; + while(*(ptr - 1) == '=') { + ptr--; + } + *(ptr) = 0; + + Serial.println(BASE_JWT_LENGTH + encodedPayloadLength); + return encodedJWT; +} + diff --git a/src/jwt/ArduinoJsonJWT.h b/src/jwt/ArduinoJsonJWT.h new file mode 100644 index 0000000..1be6c90 --- /dev/null +++ b/src/jwt/ArduinoJsonJWT.h @@ -0,0 +1,34 @@ +#ifndef ArduinoJsonJWT_H +#define ArduinoJsonJWT_H + +#include "jwt/base64.h" +#include "jwt/sha256.h" +#include "jwt/ArduinoJsonJWT.h" + +#include +#include + +class ArduinoJsonJWT { + +private: + String _psk; + + // {"alg": "HS256", "typ": "JWT"} + const char* JWT_HEADER = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9"; + const uint16_t JWT_HEADER_LENGTH = strlen(JWT_HEADER); + const uint16_t SIGNATURE_LENGTH = encode_base64_length(32); + + // static JWT length is made of: + // - the header length + // - the signature length + // - 2 delimiters, 1 terminator + const uint16_t BASE_JWT_LENGTH = JWT_HEADER_LENGTH + SIGNATURE_LENGTH + 3; + +public: + ArduinoJsonJWT(String psk); + void setPSK(String psk); + String encodeJWT(JsonObject payload); +}; + + +#endif diff --git a/src/jwt/base64.cpp b/src/jwt/base64.cpp new file mode 100644 index 0000000..0ad89b1 --- /dev/null +++ b/src/jwt/base64.cpp @@ -0,0 +1,122 @@ +unsigned char binary_to_base64(unsigned char v) { + // Capital letters - 'A' is ascii 65 and base64 0 + if(v < 26) return v + 'A'; + + // Lowercase letters - 'a' is ascii 97 and base64 26 + if(v < 52) return v + 71; + + // Digits - '0' is ascii 48 and base64 52 + if(v < 62) return v - 4; + + // '+' is ascii 43 and base64 62 + if(v == 62) return '-'; + + // '/' is ascii 47 and base64 63 + if(v == 63) return '_'; + + return 64; +} + +unsigned char base64_to_binary(unsigned char c) { + // Capital letters - 'A' is ascii 65 and base64 0 + if('A' <= c && c <= 'Z') return c - 'A'; + + // Lowercase letters - 'a' is ascii 97 and base64 26 + if('a' <= c && c <= 'z') return c - 71; + + // Digits - '0' is ascii 48 and base64 52 + if('0' <= c && c <= '9') return c + 4; + + // '+' is ascii 43 and base64 62 + if(c == '-') return 62; + + // '/' is ascii 47 and base64 63 + if(c == '_') return 63; + + return 255; +} + +unsigned int encode_base64_length(unsigned int input_length) { + return (input_length + 2)/3*4; +} + +unsigned int decode_base64_length(unsigned char input[]) { + unsigned char *start = input; + + while(base64_to_binary(input[0]) < 64) { + ++input; + } + + unsigned int input_length = input - start; + + unsigned int output_length = input_length/4*3; + + switch(input_length % 4) { + default: return output_length; + case 2: return output_length + 1; + case 3: return output_length + 2; + } +} + +unsigned int encode_base64(unsigned char input[], unsigned int input_length, unsigned char output[]) { + unsigned int full_sets = input_length/3; + + // While there are still full sets of 24 bits... + for(unsigned int i = 0; i < full_sets; ++i) { + output[0] = binary_to_base64( input[0] >> 2); + output[1] = binary_to_base64((input[0] & 0x03) << 4 | input[1] >> 4); + output[2] = binary_to_base64((input[1] & 0x0F) << 2 | input[2] >> 6); + output[3] = binary_to_base64( input[2] & 0x3F); + + input += 3; + output += 4; + } + + switch(input_length % 3) { + case 0: + output[0] = '\0'; + break; + case 1: + output[0] = binary_to_base64( input[0] >> 2); + output[1] = binary_to_base64((input[0] & 0x03) << 4); + output[2] = '='; + output[3] = '='; + output[4] = '\0'; + break; + case 2: + output[0] = binary_to_base64( input[0] >> 2); + output[1] = binary_to_base64((input[0] & 0x03) << 4 | input[1] >> 4); + output[2] = binary_to_base64((input[1] & 0x0F) << 2); + output[3] = '='; + output[4] = '\0'; + break; + } + + return encode_base64_length(input_length); +} + +unsigned int decode_base64(unsigned char input[], unsigned char output[]) { + unsigned int output_length = decode_base64_length(input); + + // While there are still full sets of 24 bits... + for(unsigned int i = 2; i < output_length; i += 3) { + output[0] = base64_to_binary(input[0]) << 2 | base64_to_binary(input[1]) >> 4; + output[1] = base64_to_binary(input[1]) << 4 | base64_to_binary(input[2]) >> 2; + output[2] = base64_to_binary(input[2]) << 6 | base64_to_binary(input[3]); + + input += 4; + output += 3; + } + + switch(output_length % 3) { + case 1: + output[0] = base64_to_binary(input[0]) << 2 | base64_to_binary(input[1]) >> 4; + break; + case 2: + output[0] = base64_to_binary(input[0]) << 2 | base64_to_binary(input[1]) >> 4; + output[1] = base64_to_binary(input[1]) << 4 | base64_to_binary(input[2]) >> 2; + break; + } + + return output_length; +} \ No newline at end of file diff --git a/src/jwt/base64.h b/src/jwt/base64.h new file mode 100644 index 0000000..fd748d2 --- /dev/null +++ b/src/jwt/base64.h @@ -0,0 +1,77 @@ +/** + * Adapted from https://github.com/Densaugeo/base64_arduino + */ + +/** + * Base64 encoding and decoding of strings. Uses '+' for 62, '\' for 63, '=' for padding + * This has been modified to use '-' for 62, '_' for 63 as per the JWT specification + */ + +#ifndef BASE64_H_INCLUDED +#define BASE64_H_INCLUDED + +/* binary_to_base64: + * Description: + * Converts a single byte from a binary value to the corresponding base64 character + * Parameters: + * v - Byte to convert + * Returns: + * ascii code of base64 character. If byte is >= 64, then there is not corresponding base64 character + * and 255 is returned + */ +unsigned char binary_to_base64(unsigned char v); + +/* base64_to_binary: + * Description: + * Converts a single byte from a base64 character to the corresponding binary value + * Parameters: + * c - Base64 character (as ascii code) + * Returns: + * 6-bit binary value + */ +unsigned char base64_to_binary(unsigned char v); + +/* encode_base64_length: + * Description: + * Calculates length of base64 string needed for a given number of binary bytes + * Parameters: + * input_length - Amount of binary data in bytes + * Returns: + * Number of base64 characters needed to encode input_length bytes of binary data + */ +unsigned int encode_base64_length(unsigned int input_length); + +/* decode_base64_length: + * Description: + * Calculates number of bytes of binary data in a base64 string + * Parameters: + * input - Base64-encoded null-terminated string + * Returns: + * Number of bytes of binary data in input + */ +unsigned int decode_base64_length(unsigned char input[]); + +/* encode_base64: + * Description: + * Converts an array of bytes to a base64 null-terminated string + * Parameters: + * input - Pointer to input data + * input_length - Number of bytes to read from input pointer + * output - Pointer to output string. Null terminator will be added automatically + * Returns: + * Length of encoded string in bytes (not including null terminator) + */ +unsigned int encode_base64(unsigned char input[], unsigned int input_length, unsigned char output[]); + +/* decode_base64: + * Description: + * Converts a base64 null-terminated string to an array of bytes + * Parameters: + * input - Pointer to input string + * output - Pointer to output array + * Returns: + * Number of bytes in the decoded binary + */ +unsigned int decode_base64(unsigned char input[], unsigned char output[]); + +#endif // ifndef \ No newline at end of file diff --git a/src/jwt/sha256.cpp b/src/jwt/sha256.cpp new file mode 100644 index 0000000..7bdb2df --- /dev/null +++ b/src/jwt/sha256.cpp @@ -0,0 +1,167 @@ +#include +#include "sha256.h" + +uint32_t sha256K[] PROGMEM = { + 0x428a2f98,0x71374491,0xb5c0fbcf,0xe9b5dba5,0x3956c25b,0x59f111f1,0x923f82a4,0xab1c5ed5, + 0xd807aa98,0x12835b01,0x243185be,0x550c7dc3,0x72be5d74,0x80deb1fe,0x9bdc06a7,0xc19bf174, + 0xe49b69c1,0xefbe4786,0x0fc19dc6,0x240ca1cc,0x2de92c6f,0x4a7484aa,0x5cb0a9dc,0x76f988da, + 0x983e5152,0xa831c66d,0xb00327c8,0xbf597fc7,0xc6e00bf3,0xd5a79147,0x06ca6351,0x14292967, + 0x27b70a85,0x2e1b2138,0x4d2c6dfc,0x53380d13,0x650a7354,0x766a0abb,0x81c2c92e,0x92722c85, + 0xa2bfe8a1,0xa81a664b,0xc24b8b70,0xc76c51a3,0xd192e819,0xd6990624,0xf40e3585,0x106aa070, + 0x19a4c116,0x1e376c08,0x2748774c,0x34b0bcb5,0x391c0cb3,0x4ed8aa4a,0x5b9cca4f,0x682e6ff3, + 0x748f82ee,0x78a5636f,0x84c87814,0x8cc70208,0x90befffa,0xa4506ceb,0xbef9a3f7,0xc67178f2 +}; + +#define BUFFER_SIZE 64 + +uint8_t sha256InitState[] PROGMEM = { + 0x67,0xe6,0x09,0x6a, // H0 + 0x85,0xae,0x67,0xbb, // H1 + 0x72,0xf3,0x6e,0x3c, // H2 + 0x3a,0xf5,0x4f,0xa5, // H3 + 0x7f,0x52,0x0e,0x51, // H4 + 0x8c,0x68,0x05,0x9b, // H5 + 0xab,0xd9,0x83,0x1f, // H6 + 0x19,0xcd,0xe0,0x5b // H7 +}; + +void Sha256Class::init(void) { + memcpy_P(state.b,sha256InitState,32); + byteCount = 0; + bufferOffset = 0; +} + +uint32_t Sha256Class::ror32(uint32_t number, uint8_t bits) { + return ((number << (32-bits)) | (number >> bits)); +} + +void Sha256Class::hashBlock() { + uint8_t i; + uint32_t a,b,c,d,e,f,g,h,t1,t2; + + a=state.w[0]; + b=state.w[1]; + c=state.w[2]; + d=state.w[3]; + e=state.w[4]; + f=state.w[5]; + g=state.w[6]; + h=state.w[7]; + + for (i=0; i<64; i++) { + if (i>=16) { + t1 = buffer.w[i&15] + buffer.w[(i-7)&15]; + t2 = buffer.w[(i-2)&15]; + t1 += ror32(t2,17) ^ ror32(t2,19) ^ (t2>>10); + t2 = buffer.w[(i-15)&15]; + t1 += ror32(t2,7) ^ ror32(t2,18) ^ (t2>>3); + buffer.w[i&15] = t1; + } + t1 = h; + t1 += ror32(e,6) ^ ror32(e,11) ^ ror32(e,25); // ∑1(e) + t1 += g ^ (e & (g ^ f)); // Ch(e,f,g) + t1 += pgm_read_dword(sha256K+i); // Ki + t1 += buffer.w[i&15]; // Wi + t2 = ror32(a,2) ^ ror32(a,13) ^ ror32(a,22); // ∑0(a) + t2 += ((b & c) | (a & (b | c))); // Maj(a,b,c) + h=g; g=f; f=e; e=d+t1; d=c; c=b; b=a; a=t1+t2; + } + state.w[0] += a; + state.w[1] += b; + state.w[2] += c; + state.w[3] += d; + state.w[4] += e; + state.w[5] += f; + state.w[6] += g; + state.w[7] += h; +} + +void Sha256Class::addUncounted(uint8_t data) { + buffer.b[bufferOffset ^ 3] = data; + bufferOffset++; + if (bufferOffset == BUFFER_SIZE) { + hashBlock(); + bufferOffset = 0; + } +} + +size_t Sha256Class::write(uint8_t data) { + ++byteCount; + addUncounted(data); + return 1; +} + +void Sha256Class::pad() { + // Implement SHA-256 padding (fips180-2 §5.1.1) + + // Pad with 0x80 followed by 0x00 until the end of the block + addUncounted(0x80); + while (bufferOffset != 56) addUncounted(0x00); + + // Append length in the last 8 bytes + addUncounted(0); // We're only using 32 bit lengths + addUncounted(0); // But SHA-1 supports 64 bit lengths + addUncounted(0); // So zero pad the top bits + addUncounted(byteCount >> 29); // Shifting to multiply by 8 + addUncounted(byteCount >> 21); // as SHA-1 supports bitstreams as well as + addUncounted(byteCount >> 13); // byte. + addUncounted(byteCount >> 5); + addUncounted(byteCount << 3); +} + + +uint8_t* Sha256Class::result(void) { + // Pad to complete the last block + pad(); + + // Swap byte order back + for (int i=0; i<8; i++) { + uint32_t a,b; + a=state.w[i]; + b=a<<24; + b|=(a<<8) & 0x00ff0000; + b|=(a>>8) & 0x0000ff00; + b|=a>>24; + state.w[i]=b; + } + + // Return pointer to hash (20 characters) + return state.b; +} + +#define HMAC_IPAD 0x36 +#define HMAC_OPAD 0x5c + +uint8_t keyBuffer[BLOCK_LENGTH]; // K0 in FIPS-198a +uint8_t innerHash[HASH_LENGTH]; + +void Sha256Class::initHmac(const uint8_t* key, int keyLength) { + uint8_t i; + memset(keyBuffer,0,BLOCK_LENGTH); + if (keyLength > BLOCK_LENGTH) { + // Hash long keys + init(); + for (;keyLength--;) write(*key++); + memcpy(keyBuffer,result(),HASH_LENGTH); + } else { + // Block length keys are used as is + memcpy(keyBuffer,key,keyLength); + } + // Start inner hash + init(); + for (i=0; i +#include "Print.h" + +#define HASH_LENGTH 32 +#define BLOCK_LENGTH 64 + +union _buffer { + uint8_t b[BLOCK_LENGTH]; + uint32_t w[BLOCK_LENGTH/4]; +}; +union _state { + uint8_t b[HASH_LENGTH]; + uint32_t w[HASH_LENGTH/4]; +}; + +class Sha256Class : public Print +{ + public: + void init(void); + void initHmac(const uint8_t* secret, int secretLength); + uint8_t* result(void); + uint8_t* resultHmac(void); + virtual size_t write(uint8_t); + using Print::write; + private: + void pad(); + void addUncounted(uint8_t data); + void hashBlock(); + uint32_t ror32(uint32_t number, uint8_t bits); + _buffer buffer; + uint8_t bufferOffset; + _state state; + uint32_t byteCount; + uint8_t keyBuffer[BLOCK_LENGTH]; + uint8_t innerHash[HASH_LENGTH]; +}; +extern Sha256Class Sha256; + +#endif \ No newline at end of file