Fork of the excellent esp8266-react - https://github.com/rjwats/esp8266-react
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

272 lines
11 KiB

  1. #ifndef WebSocketTxRx_h
  2. #define WebSocketTxRx_h
  3. #include <StatefulService.h>
  4. #include <ESPAsyncWebServer.h>
  5. #define WEB_SOCKET_CLIENT_ID_MSG_SIZE 128
  6. #define WEB_SOCKET_ORIGIN "websocket"
  7. #define WEB_SOCKET_ORIGIN_CLIENT_ID_PREFIX "websocket:"
  8. template <class T>
  9. class WebSocketConnector {
  10. protected:
  11. StatefulService<T>* _statefulService;
  12. AsyncWebServer* _server;
  13. AsyncWebSocket _webSocket;
  14. size_t _bufferSize;
  15. WebSocketConnector(StatefulService<T>* statefulService,
  16. AsyncWebServer* server,
  17. char const* webSocketPath,
  18. SecurityManager* securityManager,
  19. AuthenticationPredicate authenticationPredicate,
  20. size_t bufferSize) :
  21. _statefulService(statefulService), _server(server), _webSocket(webSocketPath), _bufferSize(bufferSize) {
  22. _webSocket.setFilter(securityManager->filterRequest(authenticationPredicate));
  23. _webSocket.onEvent(std::bind(&WebSocketConnector::onWSEvent,
  24. this,
  25. std::placeholders::_1,
  26. std::placeholders::_2,
  27. std::placeholders::_3,
  28. std::placeholders::_4,
  29. std::placeholders::_5,
  30. std::placeholders::_6));
  31. _server->addHandler(&_webSocket);
  32. _server->on(webSocketPath, HTTP_GET, std::bind(&WebSocketConnector::forbidden, this, std::placeholders::_1));
  33. }
  34. WebSocketConnector(StatefulService<T>* statefulService,
  35. AsyncWebServer* server,
  36. char const* webSocketPath,
  37. size_t bufferSize) :
  38. _statefulService(statefulService), _server(server), _webSocket(webSocketPath), _bufferSize(bufferSize) {
  39. _webSocket.onEvent(std::bind(&WebSocketConnector::onWSEvent,
  40. this,
  41. std::placeholders::_1,
  42. std::placeholders::_2,
  43. std::placeholders::_3,
  44. std::placeholders::_4,
  45. std::placeholders::_5,
  46. std::placeholders::_6));
  47. _server->addHandler(&_webSocket);
  48. }
  49. virtual void onWSEvent(AsyncWebSocket* server,
  50. AsyncWebSocketClient* client,
  51. AwsEventType type,
  52. void* arg,
  53. uint8_t* data,
  54. size_t len) = 0;
  55. String clientId(AsyncWebSocketClient* client) {
  56. return WEB_SOCKET_ORIGIN_CLIENT_ID_PREFIX + String(client->id());
  57. }
  58. private:
  59. void forbidden(AsyncWebServerRequest* request) {
  60. request->send(403);
  61. }
  62. };
  63. template <class T>
  64. class WebSocketTx : virtual public WebSocketConnector<T> {
  65. public:
  66. WebSocketTx(JsonStateReader<T> stateReader,
  67. StatefulService<T>* statefulService,
  68. AsyncWebServer* server,
  69. char const* webSocketPath,
  70. SecurityManager* securityManager,
  71. AuthenticationPredicate authenticationPredicate = AuthenticationPredicates::IS_ADMIN,
  72. size_t bufferSize = DEFAULT_BUFFER_SIZE) :
  73. WebSocketConnector<T>(statefulService,
  74. server,
  75. webSocketPath,
  76. securityManager,
  77. authenticationPredicate,
  78. bufferSize),
  79. _stateReader(stateReader) {
  80. WebSocketConnector<T>::_statefulService->addUpdateHandler(
  81. [&](const String& originId) { transmitData(nullptr, originId); }, false);
  82. }
  83. WebSocketTx(JsonStateReader<T> stateReader,
  84. StatefulService<T>* statefulService,
  85. AsyncWebServer* server,
  86. char const* webSocketPath,
  87. size_t bufferSize = DEFAULT_BUFFER_SIZE) :
  88. WebSocketConnector<T>(statefulService, server, webSocketPath, bufferSize), _stateReader(stateReader) {
  89. WebSocketConnector<T>::_statefulService->addUpdateHandler(
  90. [&](const String& originId) { transmitData(nullptr, originId); }, false);
  91. }
  92. protected:
  93. virtual void onWSEvent(AsyncWebSocket* server,
  94. AsyncWebSocketClient* client,
  95. AwsEventType type,
  96. void* arg,
  97. uint8_t* data,
  98. size_t len) {
  99. if (type == WS_EVT_CONNECT) {
  100. // when a client connects, we transmit it's id and the current payload
  101. transmitId(client);
  102. transmitData(client, WEB_SOCKET_ORIGIN);
  103. }
  104. }
  105. private:
  106. JsonStateReader<T> _stateReader;
  107. void transmitId(AsyncWebSocketClient* client) {
  108. DynamicJsonDocument jsonDocument = DynamicJsonDocument(WEB_SOCKET_CLIENT_ID_MSG_SIZE);
  109. JsonObject root = jsonDocument.to<JsonObject>();
  110. root["type"] = "id";
  111. root["id"] = WebSocketConnector<T>::clientId(client);
  112. size_t len = measureJson(jsonDocument);
  113. AsyncWebSocketMessageBuffer* buffer = WebSocketConnector<T>::_webSocket.makeBuffer(len);
  114. if (buffer) {
  115. serializeJson(jsonDocument, (char*)buffer->get(), len + 1);
  116. client->text(buffer);
  117. }
  118. }
  119. /**
  120. * Broadcasts the payload to the destination, if provided. Otherwise broadcasts to all clients except the origin, if
  121. * specified.
  122. *
  123. * Original implementation sent clients their own IDs so they could ignore updates they initiated. This approach
  124. * simplifies the client and the server implementation but may not be sufficent for all use-cases.
  125. */
  126. void transmitData(AsyncWebSocketClient* client, const String& originId) {
  127. DynamicJsonDocument jsonDocument = DynamicJsonDocument(WebSocketConnector<T>::_bufferSize);
  128. JsonObject root = jsonDocument.to<JsonObject>();
  129. root["type"] = "payload";
  130. root["origin_id"] = originId;
  131. JsonObject payload = root.createNestedObject("payload");
  132. WebSocketConnector<T>::_statefulService->read(payload, _stateReader);
  133. size_t len = measureJson(jsonDocument);
  134. AsyncWebSocketMessageBuffer* buffer = WebSocketConnector<T>::_webSocket.makeBuffer(len);
  135. if (buffer) {
  136. serializeJson(jsonDocument, (char*)buffer->get(), len + 1);
  137. if (client) {
  138. client->text(buffer);
  139. } else {
  140. WebSocketConnector<T>::_webSocket.textAll(buffer);
  141. }
  142. }
  143. }
  144. };
  145. template <class T>
  146. class WebSocketRx : virtual public WebSocketConnector<T> {
  147. public:
  148. WebSocketRx(JsonStateUpdater<T> stateUpdater,
  149. StatefulService<T>* statefulService,
  150. AsyncWebServer* server,
  151. char const* webSocketPath,
  152. SecurityManager* securityManager,
  153. AuthenticationPredicate authenticationPredicate = AuthenticationPredicates::IS_ADMIN,
  154. size_t bufferSize = DEFAULT_BUFFER_SIZE) :
  155. WebSocketConnector<T>(statefulService,
  156. server,
  157. webSocketPath,
  158. securityManager,
  159. authenticationPredicate,
  160. bufferSize),
  161. _stateUpdater(stateUpdater) {
  162. }
  163. WebSocketRx(JsonStateUpdater<T> stateUpdater,
  164. StatefulService<T>* statefulService,
  165. AsyncWebServer* server,
  166. char const* webSocketPath,
  167. size_t bufferSize = DEFAULT_BUFFER_SIZE) :
  168. WebSocketConnector<T>(statefulService, server, webSocketPath, bufferSize), _stateUpdater(stateUpdater) {
  169. }
  170. protected:
  171. virtual void onWSEvent(AsyncWebSocket* server,
  172. AsyncWebSocketClient* client,
  173. AwsEventType type,
  174. void* arg,
  175. uint8_t* data,
  176. size_t len) {
  177. if (type == WS_EVT_DATA) {
  178. AwsFrameInfo* info = (AwsFrameInfo*)arg;
  179. if (info->final && info->index == 0 && info->len == len) {
  180. if (info->opcode == WS_TEXT) {
  181. DynamicJsonDocument jsonDocument = DynamicJsonDocument(WebSocketConnector<T>::_bufferSize);
  182. DeserializationError error = deserializeJson(jsonDocument, (char*)data);
  183. if (!error && jsonDocument.is<JsonObject>()) {
  184. JsonObject jsonObject = jsonDocument.as<JsonObject>();
  185. WebSocketConnector<T>::_statefulService->update(
  186. jsonObject, _stateUpdater, WebSocketConnector<T>::clientId(client));
  187. }
  188. }
  189. }
  190. }
  191. }
  192. private:
  193. JsonStateUpdater<T> _stateUpdater;
  194. };
  195. template <class T>
  196. class WebSocketTxRx : public WebSocketTx<T>, public WebSocketRx<T> {
  197. public:
  198. WebSocketTxRx(JsonStateReader<T> stateReader,
  199. JsonStateUpdater<T> stateUpdater,
  200. StatefulService<T>* statefulService,
  201. AsyncWebServer* server,
  202. char const* webSocketPath,
  203. SecurityManager* securityManager,
  204. AuthenticationPredicate authenticationPredicate = AuthenticationPredicates::IS_ADMIN,
  205. size_t bufferSize = DEFAULT_BUFFER_SIZE) :
  206. WebSocketConnector<T>(statefulService,
  207. server,
  208. webSocketPath,
  209. securityManager,
  210. authenticationPredicate,
  211. bufferSize),
  212. WebSocketTx<T>(stateReader,
  213. statefulService,
  214. server,
  215. webSocketPath,
  216. securityManager,
  217. authenticationPredicate,
  218. bufferSize),
  219. WebSocketRx<T>(stateUpdater,
  220. statefulService,
  221. server,
  222. webSocketPath,
  223. securityManager,
  224. authenticationPredicate,
  225. bufferSize) {
  226. }
  227. WebSocketTxRx(JsonStateReader<T> stateReader,
  228. JsonStateUpdater<T> stateUpdater,
  229. StatefulService<T>* statefulService,
  230. AsyncWebServer* server,
  231. char const* webSocketPath,
  232. size_t bufferSize = DEFAULT_BUFFER_SIZE) :
  233. WebSocketConnector<T>(statefulService, server, webSocketPath, bufferSize),
  234. WebSocketTx<T>(stateReader, statefulService, server, webSocketPath, bufferSize),
  235. WebSocketRx<T>(stateUpdater, statefulService, server, webSocketPath, bufferSize) {
  236. }
  237. protected:
  238. void onWSEvent(AsyncWebSocket* server,
  239. AsyncWebSocketClient* client,
  240. AwsEventType type,
  241. void* arg,
  242. uint8_t* data,
  243. size_t len) {
  244. WebSocketRx<T>::onWSEvent(server, client, type, arg, data, len);
  245. WebSocketTx<T>::onWSEvent(server, client, type, arg, data, len);
  246. }
  247. };
  248. #endif