1use std::future::Future;
6
7use std::pin::Pin;
8
9use futures::{future, FutureExt};
10use http_body_util::BodyExt;
11use hyper::header;
12use jsonrpsee::{
13 core::BoxError,
14 server::{HttpBody, HttpRequest, HttpResponse},
15};
16use jsonrpsee_types::ErrorObject;
17use serde::{Deserialize, Serialize};
18use tower::Service;
19
20use super::cookie::Cookie;
21
22use base64::{engine::general_purpose::STANDARD, Engine as _};
23
24#[derive(Clone, Debug)]
56pub struct HttpRequestMiddleware<S> {
57 service: S,
58 cookie: Option<Cookie>,
59}
60
61impl<S> HttpRequestMiddleware<S> {
62 pub fn new(service: S, cookie: Option<Cookie>) -> Self {
64 Self { service, cookie }
65 }
66
67 pub fn check_credentials(&self, headers: &header::HeaderMap) -> bool {
69 self.cookie.as_ref().is_none_or(|internal_cookie| {
70 headers
71 .get(header::AUTHORIZATION)
72 .and_then(|auth_header| auth_header.to_str().ok())
73 .and_then(|auth_header| auth_header.split_whitespace().nth(1))
74 .and_then(|encoded| STANDARD.decode(encoded).ok())
75 .and_then(|decoded| String::from_utf8(decoded).ok())
76 .and_then(|request_cookie| request_cookie.split(':').nth(1).map(String::from))
77 .is_some_and(|passwd| internal_cookie.authenticate(passwd))
78 })
79 }
80
81 pub fn insert_or_replace_content_type_header(headers: &mut header::HeaderMap) {
103 if !headers.contains_key(header::CONTENT_TYPE)
104 || headers
105 .get(header::CONTENT_TYPE)
106 .filter(|value| {
107 value
108 .to_str()
109 .ok()
110 .unwrap_or_default()
111 .starts_with("text/plain")
112 })
113 .is_some()
114 {
115 headers.insert(
116 header::CONTENT_TYPE,
117 header::HeaderValue::from_static("application/json"),
118 );
119 }
120 }
121
122 async fn request_to_json_rpc_2(
124 request: HttpRequest<HttpBody>,
125 ) -> Result<(JsonRpcVersion, HttpRequest<HttpBody>), BoxError> {
126 let (parts, body) = request.into_parts();
127 let bytes = body.collect().await?.to_bytes();
128 let (version, bytes) =
129 if let Ok(request) = serde_json::from_slice::<'_, JsonRpcRequest>(bytes.as_ref()) {
130 let version = request.version();
131 if matches!(version, JsonRpcVersion::Unknown) {
132 (version, bytes)
133 } else {
134 (
135 version,
136 serde_json::to_vec(&request.into_2()).expect("valid").into(),
137 )
138 }
139 } else {
140 (JsonRpcVersion::Unknown, bytes)
141 };
142 Ok((
143 version,
144 HttpRequest::from_parts(parts, HttpBody::from(bytes.as_ref().to_vec())),
145 ))
146 }
147 async fn response_from_json_rpc_2(
149 version: JsonRpcVersion,
150 response: HttpResponse<HttpBody>,
151 ) -> Result<HttpResponse<HttpBody>, BoxError> {
152 let (parts, body) = response.into_parts();
153 let bytes = body.collect().await?.to_bytes();
154 let bytes =
155 if let Ok(response) = serde_json::from_slice::<'_, JsonRpcResponse>(bytes.as_ref()) {
156 serde_json::to_vec(&response.into_version(version))
157 .expect("valid")
158 .into()
159 } else {
160 bytes
161 };
162 Ok(HttpResponse::from_parts(
163 parts,
164 HttpBody::from(bytes.as_ref().to_vec()),
165 ))
166 }
167}
168
169#[derive(Clone)]
171pub struct HttpRequestMiddlewareLayer {
172 cookie: Option<Cookie>,
173}
174
175impl HttpRequestMiddlewareLayer {
176 pub fn new(cookie: Option<Cookie>) -> Self {
178 Self { cookie }
179 }
180}
181
182impl<S> tower::Layer<S> for HttpRequestMiddlewareLayer {
183 type Service = HttpRequestMiddleware<S>;
184
185 fn layer(&self, service: S) -> Self::Service {
186 HttpRequestMiddleware::new(service, self.cookie.clone())
187 }
188}
189
/// Builder-style helper for combining a value of type `T` into `self`.
pub trait With<T> {
    /// Consumes `self` and returns it updated with `value`.
    fn with(self, value: T) -> Self;
}
195
196impl<S> With<Cookie> for HttpRequestMiddleware<S> {
197 fn with(mut self, cookie: Cookie) -> Self {
198 self.cookie = Some(cookie);
199 self
200 }
201}
202
203impl<S> Service<HttpRequest<HttpBody>> for HttpRequestMiddleware<S>
204where
205 S: Service<HttpRequest, Response = HttpResponse> + std::clone::Clone + Send + 'static,
206 S::Error: Into<BoxError> + 'static,
207 S::Future: Send + 'static,
208{
209 type Response = S::Response;
210 type Error = BoxError;
211 type Future =
212 Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
213
214 fn poll_ready(
215 &mut self,
216 cx: &mut std::task::Context<'_>,
217 ) -> std::task::Poll<Result<(), Self::Error>> {
218 self.service.poll_ready(cx).map_err(Into::into)
219 }
220
221 fn call(&mut self, mut request: HttpRequest<HttpBody>) -> Self::Future {
222 if !self.check_credentials(request.headers_mut()) {
224 let error = ErrorObject::borrowed(401, "unauthenticated method", None);
225 return future::err(BoxError::from(error)).boxed();
227 }
228
229 Self::insert_or_replace_content_type_header(request.headers_mut());
231
232 let mut service = self.service.clone();
233
234 async move {
235 let (version, request) = Self::request_to_json_rpc_2(request).await?;
236 let response = service.call(request).await.map_err(Into::into)?;
237 Self::response_from_json_rpc_2(version, response).await
238 }
239 .boxed()
240 }
241}
242
/// The JSON-RPC dialect a request was written in, as classified by `JsonRpcRequest::version`.
#[derive(Debug, Clone, Copy)]
enum JsonRpcVersion {
    /// No `jsonrpc` member, with both `params` and `id` present (presumably the dialect
    /// spoken by `bitcoind`-style clients).
    Bitcoind,
    /// `"jsonrpc": "1.0"`, with both `params` and `id` present (presumably the dialect
    /// spoken by `lightwalletd`).
    Lightwalletd,
    /// `"jsonrpc": "2.0"`, with an `id` that is absent, `null`, a string or a number.
    TwoPointZero,
    /// Anything else; the request is forwarded without being rewritten.
    Unknown,
}
255
256#[derive(Debug, Deserialize, Serialize)]
258struct JsonRpcRequest {
259 #[serde(skip_serializing_if = "Option::is_none")]
260 jsonrpc: Option<String>,
261 method: String,
262 #[serde(skip_serializing_if = "Option::is_none")]
263 params: Option<serde_json::Value>,
264 #[serde(skip_serializing_if = "Option::is_none")]
265 id: Option<serde_json::Value>,
266}
267
268impl JsonRpcRequest {
269 fn version(&self) -> JsonRpcVersion {
270 match (self.jsonrpc.as_deref(), &self.params, &self.id) {
271 (
272 Some("2.0"),
273 _,
274 None
275 | Some(
276 serde_json::Value::Null
277 | serde_json::Value::String(_)
278 | serde_json::Value::Number(_),
279 ),
280 ) => JsonRpcVersion::TwoPointZero,
281 (Some("1.0"), Some(_), Some(_)) => JsonRpcVersion::Lightwalletd,
282 (None, Some(_), Some(_)) => JsonRpcVersion::Bitcoind,
283 _ => JsonRpcVersion::Unknown,
284 }
285 }
286
287 fn into_2(mut self) -> Self {
288 self.jsonrpc = Some("2.0".into());
289 self
290 }
291}
292#[derive(Debug, Deserialize, Serialize)]
294struct JsonRpcResponse {
295 #[serde(skip_serializing_if = "Option::is_none")]
296 jsonrpc: Option<String>,
297 id: serde_json::Value,
298 #[serde(skip_serializing_if = "Option::is_none")]
299 result: Option<Box<serde_json::value::RawValue>>,
300 #[serde(skip_serializing_if = "Option::is_none")]
301 error: Option<serde_json::Value>,
302}
303
304impl JsonRpcResponse {
305 fn into_version(mut self, version: JsonRpcVersion) -> Self {
306 match version {
307 JsonRpcVersion::Bitcoind => {
308 self.jsonrpc = None;
309 self.result = self
310 .result
311 .or_else(|| serde_json::value::to_raw_value(&()).ok());
312 self.error = self.error.or(Some(serde_json::Value::Null));
313 }
314 JsonRpcVersion::Lightwalletd => {
315 self.jsonrpc = Some("1.0".into());
316 self.result = self
317 .result
318 .or_else(|| serde_json::value::to_raw_value(&()).ok());
319 self.error = self.error.or(Some(serde_json::Value::Null));
320 }
321 JsonRpcVersion::TwoPointZero => {
322 assert_eq!(self.jsonrpc.as_deref(), Some("2.0"));
326 if self.error.is_none() {
327 self.result = self
328 .result
329 .or_else(|| serde_json::value::to_raw_value(&()).ok());
330 } else {
331 assert!(self.result.is_none());
332 }
333 }
334 JsonRpcVersion::Unknown => (),
335 }
336 self
337 }
338}