261 lines
7.9 KiB
Rust

mod samplers;
mod invoking_llms;
mod db;
mod config;
mod web_file_uploads;
mod web_app_state;
use axum::{
extract::{DefaultBodyLimit, Form, Query, State},
response::{Html, IntoResponse, Redirect, Response},
routing::{get, post},
Router,
};
use axum_extra::extract::cookie::{Cookie, CookieJar, SameSite};
use serde::Deserialize;
use std::future::Future;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::OnceLock;
use tera::{Tera};
use tower_http::services::{ServeDir, ServeFile};
use samplers::random_junk::random_string;
use config::{ConfigResult, GeneralServiceConfig, load_config, load_config_default};
use db::{DbResult, connect_db, init_database};
use web_file_uploads::{get_file, upload_get, upload_post};
use web_app_state::{AppState, AppStateInner, AuthenticatedUserId};
/// Builds the shared application state from the default configuration.
///
/// Loads the configuration, opens the database connection with it, and
/// compiles every template matching `frontend/pages/**/*.html`, in that order.
///
/// # Errors
/// Propagates the first failure from configuration loading, database
/// connection, or template parsing.
pub async fn init_app_state() -> Result<AppState, Box<dyn std::error::Error>> {
    let cfg = load_config_default()?;
    let database = connect_db(&cfg).await?;
    let templates = Tera::new("frontend/pages/**/*.html")?;
    let inner = AppStateInner {
        config: cfg,
        db: database,
        tera: templates,
    };
    Ok(std::sync::Arc::new(inner))
}
/// Outcome of checking a passcode against the `person` table.
enum PasscodeAuthenticationResult {
    /// No authenticated user was produced for the passcode.
    WrongPassword,
    /// The passcode matched; carries the matching person's id and name.
    Ok(AuthenticatedUserId),
}
/// Outcome of authenticating a request from its `auth` cookie.
enum WebsiteAuthenticationResult {
    /// The request carried no `auth` cookie at all.
    NoCookie,
    /// An `auth` cookie was present; holds the result of checking its value
    /// as a passcode.
    SomeCookie(PasscodeAuthenticationResult),
}
// NOTE(review): the passcode is compared as plaintext in SQL, and callers store
// the same value in the client's `auth` cookie. Consider hashed passcodes plus
// opaque session tokens.
/// Looks up the person whose stored passcode equals `passcode`.
///
/// Returns [`PasscodeAuthenticationResult::Ok`] with that person's `id` and
/// `name` when a row matches. Both "no matching row" and a failed query yield
/// [`PasscodeAuthenticationResult::WrongPassword`]. Query errors are logged to
/// stderr rather than returned to the caller.
async fn website_authentication_with_passcode(
    state: &AppStateInner,
    passcode: &str,
) -> PasscodeAuthenticationResult {
    const LOOKUP: &str = "SELECT id, name FROM person WHERE passcode = $1 LIMIT 1";
    let lookup = state.db.query_opt(LOOKUP, &[&passcode]).await;
    match lookup {
        Err(err) => {
            // Best-effort: a database failure is treated as a failed login.
            eprintln!("auth query failed: {err}");
            PasscodeAuthenticationResult::WrongPassword
        }
        Ok(None) => PasscodeAuthenticationResult::WrongPassword,
        Ok(Some(found)) => {
            let id = found.get::<_, i32>(0);
            let name = found.get::<_, String>(1);
            PasscodeAuthenticationResult::Ok(AuthenticatedUserId { id, name })
        }
    }
}
/// Authenticates a request from its `auth` cookie.
///
/// Returns [`WebsiteAuthenticationResult::NoCookie`] when the jar has no
/// `auth` cookie. Otherwise the cookie's value is checked as a passcode via
/// `website_authentication_with_passcode`, and that outcome is wrapped in
/// [`WebsiteAuthenticationResult::SomeCookie`].
async fn website_authentication_with_cookie(
    state: &AppStateInner,
    jar: &CookieJar,
) -> WebsiteAuthenticationResult {
    let Some(cookie) = jar.get("auth") else {
        return WebsiteAuthenticationResult::NoCookie;
    };
    let outcome = website_authentication_with_passcode(state, cookie.value()).await;
    WebsiteAuthenticationResult::SomeCookie(outcome)
}
/// Wraps a handler that requires a logged-in user so it can be mounted on the router.
///
/// The returned closure takes the request's [`CookieJar`] plus the wrapped
/// handler's own extra input `T` (e.g. `()` for `index`). It authenticates
/// the request via its `auth` cookie and then:
/// - on success, calls `handler(args, state, user)` and converts the result
///   into a [`Response`];
/// - when there is no `auth` cookie, redirects to `/login`;
/// - when an `auth` cookie is present but does not authenticate, redirects to
///   `/login?error=cookie`.
///
/// `state` is captured by the closure and passed to the handler as a plain
/// argument, not through axum's `State` extractor.
fn axum_handler_with_auth<H, Fut, Res, T>(
handler: H,
state: AppState,
) -> impl Fn(CookieJar, T) -> std::pin::Pin<Box<dyn Future<Output = Response> + Send>> + Clone + Send + 'static
where
H: (Fn(T, AppState, AuthenticatedUserId) -> Fut) + Send + Clone + 'static,
Fut: Future<Output = Res> + Send + 'static,
Res: IntoResponse + 'static,
T: Send + 'static,
{
move |jar: CookieJar, args: T| {
// The closure is `Fn` and runs once per request. Each call clones the captured
// values so the returned `'static` future can own its own copies.
let state = state.clone();
let handler = handler.clone();
// Boxed and pinned to match the `Pin<Box<dyn Future ...>>` return type declared above.
Box::pin(async move {
let res = website_authentication_with_cookie(state.as_ref(), &jar).await;
match res {
WebsiteAuthenticationResult::SomeCookie(PasscodeAuthenticationResult::Ok(user)) => {
handler(args, state, user).await.into_response()
}
WebsiteAuthenticationResult::NoCookie => {
Redirect::to("/login").into_response()
}
// Cookie present but its value is not a valid passcode: tell the login page why.
WebsiteAuthenticationResult::SomeCookie(PasscodeAuthenticationResult::WrongPassword) => {
Redirect::to("/login?error=cookie").into_response()
}
}
})
}
}
/// Query string accepted by `GET /login`.
#[derive(Deserialize)]
struct LoginPageQuery {
    /// Short error code from a failed login or auth redirect. `login_get`
    /// recognises `cookie`, `password` and `csrf`; anything else is ignored.
    error: Option<String>,
}
/// Form body submitted to `POST /login`.
#[derive(Deserialize)]
struct LoginPageForm {
    /// The passcode typed by the user.
    passcode: String,
    /// Echo of the token `GET /login` put in the `csrf` cookie. It must match
    /// that cookie for the login to proceed.
    csrf_token: String,
}
/// `GET /login`: renders the login page.
///
/// Every request receives a fresh random token. It is set as the `csrf`
/// cookie (HTTP-only, `SameSite=Strict`, path `/`) and also passed to the
/// template as `csrf`, so the form can echo it back for `login_post` to
/// compare.
///
/// A recognised `error` query value (`cookie`, `password`, `csrf`) becomes a
/// user-facing `error` message. If the request already authenticates through
/// its `auth` cookie, the user's name is exposed as `logged_in_cur_user`.
///
/// # Panics
/// Panics if `login.html` fails to render.
async fn login_get(
    State(state): State<AppState>,
    jar: CookieJar,
    Query(query): Query<LoginPageQuery>,
) -> (CookieJar, Html<String>) {
    let cur_auth = website_authentication_with_cookie(state.as_ref(), &jar).await;

    let csrf = random_string(32);
    let jar = jar.add(
        Cookie::build(("csrf", csrf.clone()))
            .path("/")
            .http_only(true)
            .same_site(SameSite::Strict)
            .build(),
    );

    let error = match query.error.as_deref() {
        Some("cookie") => Some("Incorrect session cookie"),
        Some("password") => Some("Invalid passcode"),
        Some("csrf") => Some("Implicit log in attempt aborted!"),
        _ => None,
    };

    let mut ctx = tera::Context::new();
    ctx.insert("csrf", &csrf);
    if let Some(msg) = error {
        ctx.insert("error", msg);
    }
    if let WebsiteAuthenticationResult::SomeCookie(PasscodeAuthenticationResult::Ok(cur_user)) =
        cur_auth
    {
        ctx.insert("logged_in_cur_user", &cur_user.name);
    }

    // The panic message names the template actually being rendered.
    let body = state.tera.render("login.html", &ctx).expect("render login");
    (jar, Html(body))
}
async fn login_post(
State(state): State<AppState>,
jar: CookieJar,
Form(form): Form<LoginPageForm>
) -> impl IntoResponse {
let csrf_ok = jar
.get("csrf")
.map(|cookie| cookie.value())
== Some(form.csrf_token.as_str());
if !csrf_ok {
return Redirect::to("/login?error=csrf").into_response();
}
let res = website_authentication_with_passcode(&state, &form.passcode);
match res.await {
PasscodeAuthenticationResult::Ok(_) => {
let jar = jar.add(
Cookie::build(("auth", form.passcode))
.path("/")
.http_only(true)
.same_site(SameSite::Strict)
.build(),
);
(jar, Redirect::to("/")).into_response()
}
PasscodeAuthenticationResult::WrongPassword =>
Redirect::to("/login?error=password").into_response()
}
}
/// `GET /` (behind `axum_handler_with_auth`): renders `index.html`, passing
/// the authenticated user's name as `logged_in_cur_user`.
///
/// # Panics
/// Panics if `index.html` fails to render.
async fn index(
    _: (),
    state: AppState,
    user: AuthenticatedUserId,
) -> impl IntoResponse {
    let context = {
        let mut c = tera::Context::new();
        c.insert("logged_in_cur_user", &user.name);
        c
    };
    let rendered = state.tera.render("index.html", &context);
    Html(rendered.expect("render index"))
}
/// `GET /welcome` (behind `axum_handler_with_auth`): renders `welcome.html`
/// with an empty template context.
///
/// # Panics
/// Panics if `welcome.html` fails to render.
async fn welcome(
    _: (),
    state: AppState,
    _user: AuthenticatedUserId,
) -> impl IntoResponse {
    let empty = tera::Context::new();
    let rendered = state.tera.render("welcome.html", &empty);
    Html(rendered.expect("render welcome"))
}
/// `GET /board` (behind `axum_handler_with_auth`): renders `board.html` with
/// an empty template context.
///
/// # Panics
/// Panics if `board.html` fails to render.
async fn board(
    _: (),
    state: AppState,
    _user: AuthenticatedUserId,
) -> impl IntoResponse {
    let body = state
        .tera
        .render("board.html", &tera::Context::new())
        // The message names the template actually rendered here.
        .expect("render board");
    Html(body)
}
/// Builds the application router and serves it on `127.0.0.1:3000`.
///
/// Routes:
/// - `/login` (GET, POST): public login page and form handler.
/// - `/upload` (GET, POST), `/get/:id`, `/welcome`, `/board` and `/`: each
///   wrapped by `axum_handler_with_auth`, so requests without a valid `auth`
///   cookie are redirected to `/login`.
/// - `/static/...`: served from `frontend/static`.
/// - `/pkg/frontend.js` and `/pkg/frontend_bg.wasm`: served from
///   `target_pkg`.
///
/// # Errors
/// Returns an error if app-state initialisation fails, the listener cannot
/// bind, or the server terminates with an I/O error.
pub async fn run_server() -> Result<(), Box<dyn std::error::Error>> {
    // Propagate start-up failures instead of panicking; this fn already returns Result.
    let state = init_app_state().await?;

    let app = Router::new()
        .route("/login", get(login_get).post(login_post))
        .route(
            "/upload",
            get(axum_handler_with_auth(upload_get, state.clone()))
                .post(axum_handler_with_auth(upload_post, state.clone())),
        )
        .route("/get/:id", get(axum_handler_with_auth(get_file, state.clone())))
        .route("/welcome", get(axum_handler_with_auth(welcome, state.clone())))
        .route("/board", get(axum_handler_with_auth(board, state.clone())))
        .route("/", get(axum_handler_with_auth(index, state.clone())))
        .nest_service("/static", ServeDir::new("frontend/static"))
        // Single files are exact routes, not nested prefixes.
        .route_service("/pkg/frontend.js", ServeFile::new("target_pkg/frontend.js"))
        .route_service(
            "/pkg/frontend_bg.wasm",
            ServeFile::new("target_pkg/frontend_bg.wasm"),
        )
        // The default request-body cap is disabled for every route, presumably
        // so `/upload` accepts large files.
        // NOTE(review): this also lifts the cap on `/login`; consider a
        // per-route limit.
        .layer(DefaultBodyLimit::disable())
        .with_state(state);

    let addr = SocketAddr::from(([127, 0, 0, 1], 3000));
    println!("listening on http://{addr}");
    let listener = tokio::net::TcpListener::bind(addr).await?;
    axum::serve(listener, app).await?;
    Ok(())
}
/// Loads the default configuration and runs `init_database` with it.
///
/// # Errors
/// Propagates failures from configuration loading or from `init_database`.
pub async fn init_db() -> Result<(), Box<dyn std::error::Error>> {
    let cfg = load_config_default()?;
    let _ = init_database(&cfg).await?;
    Ok(())
}