[ENHANCEMENT] Enhanced Types, Improved OAuth2 Configuration and Bug Fixes (#16)

* 🐛 Refactor `getPath` to `getUrl` and add `getPath` to fix detection of routes

* 🐛 Dynamically build `RegExp` with `basePath`

*  Add `ucFirst` helper

*  Add profile and tokens typing, add `contentType` to config for token fetch and use config in `RedditOAuth2Provider` instead of `getToken` override

* ✏️ Update imports in demo app
This commit is contained in:
Dan6erbond 2021-05-24 16:30:17 +02:00 committed by GitHub
parent 35f48c0cb5
commit 47cf0f1250
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 335 additions and 89 deletions

View File

@ -1,21 +1,21 @@
import { SvelteKitAuth } from "sk-auth";
import {
FacebookAuthProvider,
GoogleOAuthProvider,
RedditOAuthProvider,
FacebookOAuth2Provider,
GoogleOAuth2Provider,
RedditOAuth2Provider,
TwitterAuthProvider,
} from "sk-auth/providers";
export const appAuth = new SvelteKitAuth({
providers: [
new GoogleOAuthProvider({
new GoogleOAuth2Provider({
clientId: import.meta.env.VITE_GOOGLE_OAUTH_CLIENT_ID,
clientSecret: import.meta.env.VITE_GOOGLE_OAUTH_CLIENT_SECRET,
profile(profile) {
return { ...profile, provider: "google" };
},
}),
new FacebookAuthProvider({
new FacebookOAuth2Provider({
clientId: import.meta.env.VITE_FACEBOOK_OAUTH_CLIENT_ID,
clientSecret: import.meta.env.VITE_FACEBOOK_OAUTH_CLIENT_SECRET,
profile(profile) {
@ -29,11 +29,11 @@ export const appAuth = new SvelteKitAuth({
return { ...profile, provider: "twitter" };
},
}),
new RedditOAuthProvider({
new RedditOAuth2Provider({
apiKey: import.meta.env.VITE_REDDIT_API_KEY,
apiSecret: import.meta.env.VITE_REDDIT_API_SECRET,
profile(profile) {
profile = RedditOAuthProvider.profileHandler(profile);
profile = RedditOAuth2Provider.profileHandler(profile);
return { ...profile, provider: "reddit" };
},
}),
@ -45,8 +45,8 @@ export const appAuth = new SvelteKitAuth({
token = {
...token,
user: {
...token.user,
connections: { ...token.user.connections, [provider]: account },
...(token.user ?? {}),
connections: { ...(token.user?.connections ?? {}), [provider]: account },
},
};
}

View File

@ -26,6 +26,10 @@ interface AuthCallbacks {
export class Auth {
constructor(private readonly config?: AuthConfig) {}
get basePath() {
return this.config?.basePath ?? "/api/auth";
}
getJwtSecret() {
if (this.config?.jwtSecret) {
return this.config?.jwtSecret;
@ -68,9 +72,14 @@ export class Auth {
return this.config?.host ?? `http://${host}`;
}
getPath(path: string, host?: string) {
const uri = join([this.config?.basePath ?? "/api/auth", path]);
return new URL(uri, this.getBaseUrl(host)).pathname;
getPath(path: string) {
const pathname = join([this.basePath, path]);
return pathname;
}
getUrl(path: string, host?: string) {
const pathname = this.getPath(path);
return new URL(pathname, this.getBaseUrl(host)).href;
}
setToken(headers: Headers, newToken: JWT | any) {
@ -129,7 +138,7 @@ export class Auth {
async handleEndpoint(request: ServerRequest): Promise<EndpointOutput> {
const { path, headers, method, host } = request;
if (path === this.getPath("signout", host)) {
if (path === this.getPath("signout")) {
const token = this.setToken(headers, {});
const jwt = this.signToken(token);
@ -155,7 +164,8 @@ export class Auth {
};
}
const match = path.match(/\/api\/auth\/(?<method>signin|callback)\/(?<provider>\w+)/);
const regex = new RegExp(join([this.basePath, `(?<method>signin|callback)/(?<provider>\\w+)`]));
const match = path.match(regex);
if (match && match.groups) {
const provider = this.config?.providers?.find(

3
src/helpers.ts Normal file
View File

@ -0,0 +1,3 @@
export function ucFirst(val: string) {
return val.charAt(0).toUpperCase() + val.slice(1);
}

View File

@ -16,15 +16,15 @@ export abstract class Provider<T extends ProviderConfig = ProviderConfig> {
}
getUri(svelteKitAuth: Auth, path: string, host?: string) {
return `http://${host}${path}`;
return svelteKitAuth.getUrl(path, host);
}
getCallbackUri(svelteKitAuth: Auth, host?: string) {
return svelteKitAuth.getPath(`${"/api/auth/callback/"}${this.id}`, host);
return this.getUri(svelteKitAuth, `${"/callback/"}${this.id}`, host);
}
getSigninUri(svelteKitAuth: Auth, host?: string) {
return svelteKitAuth.getPath(`${"/api/auth/signin/"}${this.id}`, host);
return this.getUri(svelteKitAuth, `${"/signin/"}${this.id}`, host);
}
abstract signin<Locals extends Record<string, any> = Record<string, any>, Body = unknown>(

View File

@ -1,7 +1,25 @@
import { OAuth2Provider, OAuth2ProviderConfig } from "./oauth2";
interface FacebookOAuth2ProviderConfig extends OAuth2ProviderConfig {
userProfileFields?: string | string[];
export interface FacebookProfile {
id: string;
name: string;
first_name: string;
last_name: string;
name_format: string;
picture: { data: { height: number; is_silhouette: boolean; url: string; width: number } };
short_name: string;
email: string;
}
export interface FacebookTokens {
access_token: string;
token_type: string;
expires_in: number;
}
interface FacebookOAuth2ProviderConfig<ProfileType = FacebookProfile, TokensType = FacebookTokens>
extends OAuth2ProviderConfig<ProfileType, TokensType> {
userProfileFields?: string | (keyof FacebookProfile | string)[];
}
const defaultConfig: Partial<FacebookOAuth2ProviderConfig> = {
@ -19,19 +37,28 @@ const defaultConfig: Partial<FacebookOAuth2ProviderConfig> = {
"email",
],
profileUrl: "https://graph.facebook.com/me",
authorizationUrl: "https://www.facebook.com/v10.0/dialog/oauth",
accessTokenUrl: "https://graph.facebook.com/v10.0/oauth/access_token",
};
export class FacebookOAuth2Provider extends OAuth2Provider<FacebookOAuth2ProviderConfig> {
export class FacebookOAuth2Provider extends OAuth2Provider<
FacebookProfile,
FacebookTokens,
FacebookOAuth2ProviderConfig
> {
constructor(config: FacebookOAuth2ProviderConfig) {
const userProfileFields = config.userProfileFields || defaultConfig.userProfileFields;
const profileUrl = `${config.profileUrl || defaultConfig.profileUrl}?${
Array.isArray(userProfileFields) ? userProfileFields.join(",") : userProfileFields
}`;
const userProfileFields = config.userProfileFields ?? defaultConfig.userProfileFields;
const data = {
fields: Array.isArray(userProfileFields) ? userProfileFields.join(",") : userProfileFields!,
};
const profileUrl = `${config.profileUrl ?? defaultConfig.profileUrl}?${new URLSearchParams(
data,
)}`;
super({
...defaultConfig,
profileUrl,
...config,
profileUrl,
});
}
}

View File

@ -1,15 +1,35 @@
import { OAuth2Provider, OAuth2ProviderConfig } from "./oauth2";
const defaultConfig: Partial<OAuth2ProviderConfig> = {
export interface GoogleProfile {
sub: string;
name: string;
give_name: string;
picture: string;
email: string;
email_verified: boolean;
locale: string;
}
export interface GoogleTokens {
access_token: string;
expires_in: number;
scope: string;
token_type: string;
id_token: string;
}
type GoogleOAuth2ProviderConfig = OAuth2ProviderConfig<GoogleProfile, GoogleTokens>;
const defaultConfig: Partial<GoogleOAuth2ProviderConfig> = {
id: "google",
scope: ["openid", "profile", "email"],
accessTokenUrl: "https://accounts.google.com/o/oauth2/token",
authorizationUrl: "https://accounts.google.com/o/oauth2/auth?response_type=code",
authorizationUrl: "https://accounts.google.com/o/oauth2/auth",
profileUrl: "https://openidconnect.googleapis.com/v1/userinfo",
};
export class GoogleOAuth2Provider extends OAuth2Provider {
constructor(config: OAuth2ProviderConfig) {
export class GoogleOAuth2Provider extends OAuth2Provider<GoogleOAuth2ProviderConfig> {
constructor(config: GoogleOAuth2ProviderConfig) {
super({
...defaultConfig,
...config,

View File

@ -1,6 +1,11 @@
export { Provider } from "./base";
export { GoogleOAuth2Provider as GoogleOAuthProvider } from "./google";
export { GoogleOAuth2Provider } from "./google";
export type { GoogleProfile, GoogleTokens } from "./google";
export { TwitterAuthProvider } from "./twitter";
export { FacebookOAuth2Provider as FacebookAuthProvider } from "./facebook";
export { OAuth2BaseProvider as OAuth2Provider } from "./oauth2.base";
export { RedditOAuth2Provider as RedditOAuthProvider } from "./reddit";
export { FacebookOAuth2Provider } from "./facebook";
export type { FacebookProfile, FacebookTokens } from "./facebook";
export { OAuth2BaseProvider } from "./oauth2.base";
export type { ProfileCallback } from "./oauth2.base";
export { OAuth2Provider } from "./oauth2";
export { RedditOAuth2Provider } from "./reddit";
export type { RedditProfile, RedditTokens } from "./reddit";

View File

@ -3,24 +3,41 @@ import type { Auth } from "../auth";
import type { CallbackResult } from "../types";
import { Provider, ProviderConfig } from "./base";
export interface OAuth2BaseProviderConfig extends ProviderConfig {
profile?: (profile: any, tokens: any) => any | Promise<any>;
export interface OAuth2Tokens {
access_token: string;
token_type: string;
}
export abstract class OAuth2BaseProvider<T extends OAuth2BaseProviderConfig> extends Provider<T> {
export type ProfileCallback<ProfileType = any, TokensType = any, ReturnType = any> = (
profile: ProfileType,
tokens: TokensType,
) => ReturnType | Promise<ReturnType>;
export interface OAuth2BaseProviderConfig<ProfileType = any, TokensType = any>
extends ProviderConfig {
profile?: ProfileCallback<ProfileType, TokensType>;
}
export abstract class OAuth2BaseProvider<
ProfileType,
TokensType extends OAuth2Tokens,
T extends OAuth2BaseProviderConfig,
> extends Provider<T> {
abstract getAuthorizationUrl(
request: ServerRequest,
auth: Auth,
state: string,
nonce: string,
): string | Promise<string>;
abstract getTokens(code: string, redirectUri: string): any | Promise<any>;
abstract getUserProfile(tokens: any): any | Promise<any>;
abstract getTokens(code: string, redirectUri: string): TokensType | Promise<TokensType>;
abstract getUserProfile(tokens: any): ProfileType | Promise<ProfileType>;
async signin(request: ServerRequest, auth: Auth): Promise<EndpointOutput> {
const { method, host, query } = request;
const state = [`redirect=${query.get("redirect") ?? this.getUri(auth, host, "/")}`].join(",");
const state = [`redirect=${query.get("redirect") ?? this.getUri(auth, "/", host)}`].join(",");
const base64State = Buffer.from(state).toString("base64");
const url = await this.getAuthorizationUrl(request, auth, base64State);
const nonce = Math.round(Math.random() * 1000).toString(); // TODO: Generate random based on user values
const url = await this.getAuthorizationUrl(request, auth, base64State, nonce);
if (method === "POST") {
return {
@ -59,6 +76,6 @@ export abstract class OAuth2BaseProvider<T extends OAuth2BaseProviderConfig> ext
user = await this.config.profile(user, tokens);
}
return [user, redirect ?? this.getUri(auth, host, "/")];
return [user, redirect ?? this.getUri(auth, "/", host)];
}
}

View File

@ -1,8 +1,10 @@
import type { ServerRequest } from "@sveltejs/kit/types/endpoint";
import type { Auth } from "../auth";
import { OAuth2BaseProvider, OAuth2BaseProviderConfig } from "./oauth2.base";
import { ucFirst } from "../helpers";
import { OAuth2BaseProvider, OAuth2BaseProviderConfig, OAuth2Tokens } from "./oauth2.base";
export interface OAuth2ProviderConfig extends OAuth2BaseProviderConfig {
export interface OAuth2ProviderConfig<ProfileType = any, TokensType extends OAuth2Tokens = any>
extends OAuth2BaseProviderConfig<ProfileType, TokensType> {
accessTokenUrl?: string;
authorizationUrl?: string;
profileUrl?: string;
@ -14,31 +16,35 @@ export interface OAuth2ProviderConfig extends OAuth2BaseProviderConfig {
params: any;
grantType?: string;
responseType?: string;
contentType?: "application/json" | "application/x-www-form-urlencoded";
}
const defaultConfig: Partial<OAuth2ProviderConfig> = {
responseType: "code",
grantType: "authorization_code",
contentType: "application/json",
};
export class OAuth2Provider<
T extends OAuth2ProviderConfig = OAuth2ProviderConfig,
> extends OAuth2BaseProvider<T> {
constructor(config: T) {
ProfileType = any,
TokensType extends OAuth2Tokens = OAuth2Tokens,
ConfigType extends OAuth2ProviderConfig = OAuth2ProviderConfig<ProfileType, TokensType>,
> extends OAuth2BaseProvider<ProfileType, TokensType, ConfigType> {
constructor(config: ConfigType) {
super({
...defaultConfig,
...config,
});
}
getAuthorizationUrl({ host }: ServerRequest, auth: Auth, state: string) {
getAuthorizationUrl({ host }: ServerRequest, auth: Auth, state: string, nonce: string) {
const data = {
state,
nonce,
response_type: this.config.responseType,
client_id: this.config.clientId,
scope: Array.isArray(this.config.scope) ? this.config.scope.join(" ") : this.config.scope,
redirect_uri: this.getCallbackUri(auth, host),
nonce: Math.round(Math.random() * 1000).toString(), // TODO: Generate random based on user values
...(this.config.authorizationParams ?? {}),
};
@ -46,8 +52,8 @@ export class OAuth2Provider<
return url;
}
async getTokens(code: string, redirectUri: string) {
const data = {
async getTokens(code: string, redirectUri: string): Promise<TokensType> {
const data: Record<string, any> = {
code,
grant_type: this.config.grantType,
client_id: this.config.clientId,
@ -56,13 +62,30 @@ export class OAuth2Provider<
...(this.config.params ?? {}),
};
const res = await fetch(`${this.config.accessTokenUrl}?${new URLSearchParams(data)}`);
let body: string;
if (this.config.contentType === "application/x-www-form-urlencoded") {
body = Object.entries(data)
.map(([key, value]) => `${encodeURIComponent(key)}=${encodeURIComponent(value)}`)
.join("&");
} else {
body = JSON.stringify(data);
}
const res = await fetch(this.config.accessTokenUrl!, {
body,
method: "POST",
headers: {
"Content-Type": this.config.contentType,
...(this.config.headers ?? {}),
},
});
return await res.json();
}
async getUserProfile(tokens: any) {
async getUserProfile(tokens: TokensType): Promise<ProfileType> {
const res = await fetch(this.config.profileUrl!, {
headers: { Authorization: `${tokens.token_type} ${tokens.access_token}` },
headers: { Authorization: `${ucFirst(tokens.token_type)} ${tokens.access_token}` },
});
return await res.json();
}

View File

@ -1,12 +1,165 @@
import { OAuth2Provider, OAuth2ProviderConfig } from "./oauth2";
import type { ProfileCallback } from "./oauth2.base";
interface RedditOAuth2ProviderConfig extends OAuth2ProviderConfig {
export interface RedditProfile {
is_employee: boolean;
seen_layout_switch: boolean;
has_visited_new_profile: boolean;
pref_no_profanity: boolean;
has_external_account: boolean;
pref_geopopular: string;
seen_redesign_modal: boolean;
pref_show_trending: boolean;
subreddit: Subreddit;
pref_show_presence: boolean;
snoovatar_img: string;
snoovatar_size: number[];
gold_expiration: number;
has_gold_subscription: boolean;
is_sponsor: boolean;
num_friends: number;
features: Features;
can_edit_name: boolean;
verified: boolean;
pref_autoplay: boolean;
coins: number;
has_paypal_subscription: boolean;
has_subscribed_to_premium: boolean;
id: string;
has_stripe_subscription: boolean;
oauth_client_id: string;
can_create_subreddit: boolean;
over_18: boolean;
is_gold: boolean;
is_mod: boolean;
awarder_karma: number;
suspension_expiration_utc: null;
has_verified_email: boolean;
is_suspended: boolean;
pref_video_autoplay: boolean;
has_android_subscription: boolean;
in_redesign_beta: boolean;
icon_img: string;
pref_nightmode: boolean;
awardee_karma: number;
hide_from_robots: boolean;
password_set: boolean;
link_karma: number;
force_password_reset: boolean;
total_karma: number;
seen_give_award_tooltip: boolean;
inbox_count: number;
seen_premium_adblock_modal: boolean;
pref_top_karma_subreddits: boolean;
pref_show_snoovatar: boolean;
name: string;
pref_clickgadget: number;
created: number;
gold_creddits: number;
created_utc: number;
has_ios_subscription: boolean;
pref_show_twitter: boolean;
in_beta: boolean;
comment_karma: number;
has_subscribed: boolean;
linked_identities: string[];
seen_subreddit_chat_ftux: boolean;
}
export interface Features {
mod_service_mute_writes: boolean;
promoted_trend_blanks: boolean;
show_amp_link: boolean;
chat: boolean;
is_email_permission_required: boolean;
mod_awards: boolean;
expensive_coins_package: boolean;
mweb_xpromo_revamp_v2: MwebXpromoRevampV;
awards_on_streams: boolean;
webhook_config: boolean;
mweb_xpromo_modal_listing_click_daily_dismissible_ios: boolean;
live_orangereds: boolean;
modlog_copyright_removal: boolean;
show_nps_survey: boolean;
do_not_track: boolean;
mod_service_mute_reads: boolean;
chat_user_settings: boolean;
use_pref_account_deployment: boolean;
mweb_xpromo_interstitial_comments_ios: boolean;
chat_subreddit: boolean;
noreferrer_to_noopener: boolean;
premium_subscriptions_table: boolean;
mweb_xpromo_interstitial_comments_android: boolean;
chat_group_rollout: boolean;
resized_styles_images: boolean;
spez_modal: boolean;
mweb_xpromo_modal_listing_click_daily_dismissible_android: boolean;
mweb_xpromo_revamp_v3: MwebXpromoRevampV;
}
export interface MwebXpromoRevampV {
owner: string;
variant: string;
experiment_id: number;
}
export interface Subreddit {
default_set: boolean;
user_is_contributor: boolean;
banner_img: string;
restrict_posting: boolean;
user_is_banned: boolean;
free_form_reports: boolean;
community_icon: null;
show_media: boolean;
icon_color: string;
user_is_muted: boolean;
display_name: string;
header_img: null;
title: string;
coins: number;
previous_names: any[];
over_18: boolean;
icon_size: number[];
primary_color: string;
icon_img: string;
description: string;
submit_link_label: string;
header_size: null;
restrict_commenting: boolean;
subscribers: number;
submit_text_label: string;
is_default_icon: boolean;
link_flair_position: string;
display_name_prefixed: string;
key_color: string;
name: string;
is_default_banner: boolean;
url: string;
quarantine: boolean;
banner_size: number[];
user_is_moderator: boolean;
public_description: string;
link_flair_enabled: boolean;
disable_contributor_requests: boolean;
subreddit_type: string;
user_is_subscriber: boolean;
}
export interface RedditTokens {
access_token: string;
token_type: string;
expires_in: number;
scope: string;
}
interface RedditOAuth2ProviderConfig extends OAuth2ProviderConfig<RedditProfile, RedditTokens> {
duration?: "temporary" | "permanent";
apiKey: string;
apiSecret: string;
}
const redditProfileHandler = ({
const redditProfileHandler: ProfileCallback<RedditProfile, RedditTokens> = ({
is_employee,
has_external_account,
snoovatar_img,
@ -55,48 +208,35 @@ const redditProfileHandler = ({
const defaultConfig: Partial<RedditOAuth2ProviderConfig> = {
id: "reddit",
scope: "identity",
duration: "temporary",
profile: redditProfileHandler,
authorizationUrl: "https://www.reddit.com/api/v1/authorize",
accessTokenUrl: "https://www.reddit.com/api/v1/access_token",
profileUrl: "https://oauth.reddit.com/api/v1/me",
contentType: "application/x-www-form-urlencoded",
};
export class RedditOAuth2Provider extends OAuth2Provider<RedditOAuth2ProviderConfig> {
export class RedditOAuth2Provider extends OAuth2Provider<
RedditProfile,
RedditTokens,
RedditOAuth2ProviderConfig
> {
constructor(config: RedditOAuth2ProviderConfig) {
super({
...defaultConfig,
...config,
clientId: config.apiKey,
clientSecret: config.apiSecret,
headers: {
...config.headers,
Authorization:
"Basic " + Buffer.from(`${config.apiKey}:${config.apiSecret}`).toString("base64"),
},
authorizationParams: {
...config.authorizationParams,
duration: config.duration ?? "temporary",
},
});
}
static profileHandler = redditProfileHandler;
async getTokens(code: string, redirectUri: string) {
const endpoint = this.config.accessTokenUrl!;
const data = {
code,
redirect_uri: redirectUri,
grant_type: this.config.grantType!,
};
const body = Object.entries(data)
.map(([key, value]) => `${encodeURIComponent(key)}=${encodeURIComponent(value)}`)
.join("&");
const res = await fetch(endpoint, {
method: "POST",
headers: {
"Content-Type": "application/x-www-form-urlencoded",
Authorization:
"Basic " +
Buffer.from(`${this.config.clientId}:${this.config.clientSecret}`).toString("base64"),
},
body,
});
return await res.json();
}
}

View File

@ -1,4 +1,5 @@
import type { ServerRequest } from "@sveltejs/kit/types/endpoint";
import type { Auth } from "../auth";
import type { CallbackResult } from "../types";
import { OAuth2BaseProvider, OAuth2BaseProviderConfig } from "./oauth2.base";
@ -19,11 +20,11 @@ export class TwitterAuthProvider extends OAuth2BaseProvider<TwitterAuthProviderC
});
}
async getRequestToken(host: string) {
async getRequestToken(auth: Auth, host?: string) {
const endpoint = "https://api.twitter.com/oauth/request_token";
const data = {
oauth_callback: encodeURIComponent(this.getCallbackUri(host)),
oauth_callback: encodeURIComponent(this.getCallbackUri(auth, host)),
oauth_consumer_key: this.config.apiKey,
};
@ -37,10 +38,10 @@ export class TwitterAuthProvider extends OAuth2BaseProvider<TwitterAuthProviderC
};
}
async getAuthorizationUrl({ host }: ServerRequest) {
async getAuthorizationUrl({ host }: ServerRequest, auth: Auth, state: string, nonce: string) {
const endpoint = "https://api.twitter.com/oauth/authorize";
const { oauthToken } = await this.getRequestToken(host);
const { oauthToken } = await this.getRequestToken(auth, host);
const data = {
oauth_token: oauthToken,
@ -70,7 +71,7 @@ export class TwitterAuthProvider extends OAuth2BaseProvider<TwitterAuthProviderC
return await res.json();
}
async callback({ query, host }: ServerRequest): Promise<CallbackResult> {
async callback({ query, host }: ServerRequest, auth: Auth): Promise<CallbackResult> {
const oauthToken = query.get("oauth_token");
const oauthVerifier = query.get("oauth_verifier");
const redirect = this.getStateValue(query, "redirect");
@ -82,6 +83,6 @@ export class TwitterAuthProvider extends OAuth2BaseProvider<TwitterAuthProviderC
user = await this.config.profile(user, tokens);
}
return [user, redirect ?? this.getUri(host, "/")];
return [user, redirect ?? this.getUri(auth, "/", host)];
}
}