[ENHANCEMENT] Enhanced Types, Improved OAuth2 Configuration and Bug Fixes (#16)

* 🐛 Refactor `getPath` to `getUrl` and add `getPath` to fix detection of routes

* 🐛 Dynamically build `RegExp` with `basePath`

*  Add `ucFirst` helper

*  Add profile and tokens typing, add `contentType` to config for token fetch and use config in `RedditOAuth2Provider` instead of `getToken` override

* ✏️ Update imports in demo app
This commit is contained in:
Dan6erbond 2021-05-24 16:30:17 +02:00 committed by GitHub
parent 35f48c0cb5
commit 47cf0f1250
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 335 additions and 89 deletions

View File

@ -1,21 +1,21 @@
import { SvelteKitAuth } from "sk-auth"; import { SvelteKitAuth } from "sk-auth";
import { import {
FacebookAuthProvider, FacebookOAuth2Provider,
GoogleOAuthProvider, GoogleOAuth2Provider,
RedditOAuthProvider, RedditOAuth2Provider,
TwitterAuthProvider, TwitterAuthProvider,
} from "sk-auth/providers"; } from "sk-auth/providers";
export const appAuth = new SvelteKitAuth({ export const appAuth = new SvelteKitAuth({
providers: [ providers: [
new GoogleOAuthProvider({ new GoogleOAuth2Provider({
clientId: import.meta.env.VITE_GOOGLE_OAUTH_CLIENT_ID, clientId: import.meta.env.VITE_GOOGLE_OAUTH_CLIENT_ID,
clientSecret: import.meta.env.VITE_GOOGLE_OAUTH_CLIENT_SECRET, clientSecret: import.meta.env.VITE_GOOGLE_OAUTH_CLIENT_SECRET,
profile(profile) { profile(profile) {
return { ...profile, provider: "google" }; return { ...profile, provider: "google" };
}, },
}), }),
new FacebookAuthProvider({ new FacebookOAuth2Provider({
clientId: import.meta.env.VITE_FACEBOOK_OAUTH_CLIENT_ID, clientId: import.meta.env.VITE_FACEBOOK_OAUTH_CLIENT_ID,
clientSecret: import.meta.env.VITE_FACEBOOK_OAUTH_CLIENT_SECRET, clientSecret: import.meta.env.VITE_FACEBOOK_OAUTH_CLIENT_SECRET,
profile(profile) { profile(profile) {
@ -29,11 +29,11 @@ export const appAuth = new SvelteKitAuth({
return { ...profile, provider: "twitter" }; return { ...profile, provider: "twitter" };
}, },
}), }),
new RedditOAuthProvider({ new RedditOAuth2Provider({
apiKey: import.meta.env.VITE_REDDIT_API_KEY, apiKey: import.meta.env.VITE_REDDIT_API_KEY,
apiSecret: import.meta.env.VITE_REDDIT_API_SECRET, apiSecret: import.meta.env.VITE_REDDIT_API_SECRET,
profile(profile) { profile(profile) {
profile = RedditOAuthProvider.profileHandler(profile); profile = RedditOAuth2Provider.profileHandler(profile);
return { ...profile, provider: "reddit" }; return { ...profile, provider: "reddit" };
}, },
}), }),
@ -45,8 +45,8 @@ export const appAuth = new SvelteKitAuth({
token = { token = {
...token, ...token,
user: { user: {
...token.user, ...(token.user ?? {}),
connections: { ...token.user.connections, [provider]: account }, connections: { ...(token.user?.connections ?? {}), [provider]: account },
}, },
}; };
} }

View File

@ -26,6 +26,10 @@ interface AuthCallbacks {
export class Auth { export class Auth {
constructor(private readonly config?: AuthConfig) {} constructor(private readonly config?: AuthConfig) {}
get basePath() {
return this.config?.basePath ?? "/api/auth";
}
getJwtSecret() { getJwtSecret() {
if (this.config?.jwtSecret) { if (this.config?.jwtSecret) {
return this.config?.jwtSecret; return this.config?.jwtSecret;
@ -68,9 +72,14 @@ export class Auth {
return this.config?.host ?? `http://${host}`; return this.config?.host ?? `http://${host}`;
} }
getPath(path: string, host?: string) { getPath(path: string) {
const uri = join([this.config?.basePath ?? "/api/auth", path]); const pathname = join([this.basePath, path]);
return new URL(uri, this.getBaseUrl(host)).pathname; return pathname;
}
getUrl(path: string, host?: string) {
const pathname = this.getPath(path);
return new URL(pathname, this.getBaseUrl(host)).href;
} }
setToken(headers: Headers, newToken: JWT | any) { setToken(headers: Headers, newToken: JWT | any) {
@ -129,7 +138,7 @@ export class Auth {
async handleEndpoint(request: ServerRequest): Promise<EndpointOutput> { async handleEndpoint(request: ServerRequest): Promise<EndpointOutput> {
const { path, headers, method, host } = request; const { path, headers, method, host } = request;
if (path === this.getPath("signout", host)) { if (path === this.getPath("signout")) {
const token = this.setToken(headers, {}); const token = this.setToken(headers, {});
const jwt = this.signToken(token); const jwt = this.signToken(token);
@ -155,7 +164,8 @@ export class Auth {
}; };
} }
const match = path.match(/\/api\/auth\/(?<method>signin|callback)\/(?<provider>\w+)/); const regex = new RegExp(join([this.basePath, `(?<method>signin|callback)/(?<provider>\\w+)`]));
const match = path.match(regex);
if (match && match.groups) { if (match && match.groups) {
const provider = this.config?.providers?.find( const provider = this.config?.providers?.find(

3
src/helpers.ts Normal file
View File

@ -0,0 +1,3 @@
export function ucFirst(val: string) {
return val.charAt(0).toUpperCase() + val.slice(1);
}

View File

@ -16,15 +16,15 @@ export abstract class Provider<T extends ProviderConfig = ProviderConfig> {
} }
getUri(svelteKitAuth: Auth, path: string, host?: string) { getUri(svelteKitAuth: Auth, path: string, host?: string) {
return `http://${host}${path}`; return svelteKitAuth.getUrl(path, host);
} }
getCallbackUri(svelteKitAuth: Auth, host?: string) { getCallbackUri(svelteKitAuth: Auth, host?: string) {
return svelteKitAuth.getPath(`${"/api/auth/callback/"}${this.id}`, host); return this.getUri(svelteKitAuth, `${"/callback/"}${this.id}`, host);
} }
getSigninUri(svelteKitAuth: Auth, host?: string) { getSigninUri(svelteKitAuth: Auth, host?: string) {
return svelteKitAuth.getPath(`${"/api/auth/signin/"}${this.id}`, host); return this.getUri(svelteKitAuth, `${"/signin/"}${this.id}`, host);
} }
abstract signin<Locals extends Record<string, any> = Record<string, any>, Body = unknown>( abstract signin<Locals extends Record<string, any> = Record<string, any>, Body = unknown>(

View File

@ -1,7 +1,25 @@
import { OAuth2Provider, OAuth2ProviderConfig } from "./oauth2"; import { OAuth2Provider, OAuth2ProviderConfig } from "./oauth2";
interface FacebookOAuth2ProviderConfig extends OAuth2ProviderConfig { export interface FacebookProfile {
userProfileFields?: string | string[]; id: string;
name: string;
first_name: string;
last_name: string;
name_format: string;
picture: { data: { height: number; is_silhouette: boolean; url: string; width: number } };
short_name: string;
email: string;
}
export interface FacebookTokens {
access_token: string;
token_type: string;
expires_in: number;
}
interface FacebookOAuth2ProviderConfig<ProfileType = FacebookProfile, TokensType = FacebookTokens>
extends OAuth2ProviderConfig<ProfileType, TokensType> {
userProfileFields?: string | (keyof FacebookProfile | string)[];
} }
const defaultConfig: Partial<FacebookOAuth2ProviderConfig> = { const defaultConfig: Partial<FacebookOAuth2ProviderConfig> = {
@ -19,19 +37,28 @@ const defaultConfig: Partial<FacebookOAuth2ProviderConfig> = {
"email", "email",
], ],
profileUrl: "https://graph.facebook.com/me", profileUrl: "https://graph.facebook.com/me",
authorizationUrl: "https://www.facebook.com/v10.0/dialog/oauth",
accessTokenUrl: "https://graph.facebook.com/v10.0/oauth/access_token",
}; };
export class FacebookOAuth2Provider extends OAuth2Provider<FacebookOAuth2ProviderConfig> { export class FacebookOAuth2Provider extends OAuth2Provider<
FacebookProfile,
FacebookTokens,
FacebookOAuth2ProviderConfig
> {
constructor(config: FacebookOAuth2ProviderConfig) { constructor(config: FacebookOAuth2ProviderConfig) {
const userProfileFields = config.userProfileFields || defaultConfig.userProfileFields; const userProfileFields = config.userProfileFields ?? defaultConfig.userProfileFields;
const profileUrl = `${config.profileUrl || defaultConfig.profileUrl}?${ const data = {
Array.isArray(userProfileFields) ? userProfileFields.join(",") : userProfileFields fields: Array.isArray(userProfileFields) ? userProfileFields.join(",") : userProfileFields!,
}`; };
const profileUrl = `${config.profileUrl ?? defaultConfig.profileUrl}?${new URLSearchParams(
data,
)}`;
super({ super({
...defaultConfig, ...defaultConfig,
profileUrl,
...config, ...config,
profileUrl,
}); });
} }
} }

View File

@ -1,15 +1,35 @@
import { OAuth2Provider, OAuth2ProviderConfig } from "./oauth2"; import { OAuth2Provider, OAuth2ProviderConfig } from "./oauth2";
const defaultConfig: Partial<OAuth2ProviderConfig> = { export interface GoogleProfile {
sub: string;
name: string;
give_name: string;
picture: string;
email: string;
email_verified: boolean;
locale: string;
}
export interface GoogleTokens {
access_token: string;
expires_in: number;
scope: string;
token_type: string;
id_token: string;
}
type GoogleOAuth2ProviderConfig = OAuth2ProviderConfig<GoogleProfile, GoogleTokens>;
const defaultConfig: Partial<GoogleOAuth2ProviderConfig> = {
id: "google", id: "google",
scope: ["openid", "profile", "email"], scope: ["openid", "profile", "email"],
accessTokenUrl: "https://accounts.google.com/o/oauth2/token", accessTokenUrl: "https://accounts.google.com/o/oauth2/token",
authorizationUrl: "https://accounts.google.com/o/oauth2/auth?response_type=code", authorizationUrl: "https://accounts.google.com/o/oauth2/auth",
profileUrl: "https://openidconnect.googleapis.com/v1/userinfo", profileUrl: "https://openidconnect.googleapis.com/v1/userinfo",
}; };
export class GoogleOAuth2Provider extends OAuth2Provider { export class GoogleOAuth2Provider extends OAuth2Provider<GoogleOAuth2ProviderConfig> {
constructor(config: OAuth2ProviderConfig) { constructor(config: GoogleOAuth2ProviderConfig) {
super({ super({
...defaultConfig, ...defaultConfig,
...config, ...config,

View File

@ -1,6 +1,11 @@
export { Provider } from "./base"; export { Provider } from "./base";
export { GoogleOAuth2Provider as GoogleOAuthProvider } from "./google"; export { GoogleOAuth2Provider } from "./google";
export type { GoogleProfile, GoogleTokens } from "./google";
export { TwitterAuthProvider } from "./twitter"; export { TwitterAuthProvider } from "./twitter";
export { FacebookOAuth2Provider as FacebookAuthProvider } from "./facebook"; export { FacebookOAuth2Provider } from "./facebook";
export { OAuth2BaseProvider as OAuth2Provider } from "./oauth2.base"; export type { FacebookProfile, FacebookTokens } from "./facebook";
export { RedditOAuth2Provider as RedditOAuthProvider } from "./reddit"; export { OAuth2BaseProvider } from "./oauth2.base";
export type { ProfileCallback } from "./oauth2.base";
export { OAuth2Provider } from "./oauth2";
export { RedditOAuth2Provider } from "./reddit";
export type { RedditProfile, RedditTokens } from "./reddit";

View File

@ -3,24 +3,41 @@ import type { Auth } from "../auth";
import type { CallbackResult } from "../types"; import type { CallbackResult } from "../types";
import { Provider, ProviderConfig } from "./base"; import { Provider, ProviderConfig } from "./base";
export interface OAuth2BaseProviderConfig extends ProviderConfig { export interface OAuth2Tokens {
profile?: (profile: any, tokens: any) => any | Promise<any>; access_token: string;
token_type: string;
} }
export abstract class OAuth2BaseProvider<T extends OAuth2BaseProviderConfig> extends Provider<T> { export type ProfileCallback<ProfileType = any, TokensType = any, ReturnType = any> = (
profile: ProfileType,
tokens: TokensType,
) => ReturnType | Promise<ReturnType>;
export interface OAuth2BaseProviderConfig<ProfileType = any, TokensType = any>
extends ProviderConfig {
profile?: ProfileCallback<ProfileType, TokensType>;
}
export abstract class OAuth2BaseProvider<
ProfileType,
TokensType extends OAuth2Tokens,
T extends OAuth2BaseProviderConfig,
> extends Provider<T> {
abstract getAuthorizationUrl( abstract getAuthorizationUrl(
request: ServerRequest, request: ServerRequest,
auth: Auth, auth: Auth,
state: string, state: string,
nonce: string,
): string | Promise<string>; ): string | Promise<string>;
abstract getTokens(code: string, redirectUri: string): any | Promise<any>; abstract getTokens(code: string, redirectUri: string): TokensType | Promise<TokensType>;
abstract getUserProfile(tokens: any): any | Promise<any>; abstract getUserProfile(tokens: any): ProfileType | Promise<ProfileType>;
async signin(request: ServerRequest, auth: Auth): Promise<EndpointOutput> { async signin(request: ServerRequest, auth: Auth): Promise<EndpointOutput> {
const { method, host, query } = request; const { method, host, query } = request;
const state = [`redirect=${query.get("redirect") ?? this.getUri(auth, host, "/")}`].join(","); const state = [`redirect=${query.get("redirect") ?? this.getUri(auth, "/", host)}`].join(",");
const base64State = Buffer.from(state).toString("base64"); const base64State = Buffer.from(state).toString("base64");
const url = await this.getAuthorizationUrl(request, auth, base64State); const nonce = Math.round(Math.random() * 1000).toString(); // TODO: Generate random based on user values
const url = await this.getAuthorizationUrl(request, auth, base64State, nonce);
if (method === "POST") { if (method === "POST") {
return { return {
@ -59,6 +76,6 @@ export abstract class OAuth2BaseProvider<T extends OAuth2BaseProviderConfig> ext
user = await this.config.profile(user, tokens); user = await this.config.profile(user, tokens);
} }
return [user, redirect ?? this.getUri(auth, host, "/")]; return [user, redirect ?? this.getUri(auth, "/", host)];
} }
} }

View File

@ -1,8 +1,10 @@
import type { ServerRequest } from "@sveltejs/kit/types/endpoint"; import type { ServerRequest } from "@sveltejs/kit/types/endpoint";
import type { Auth } from "../auth"; import type { Auth } from "../auth";
import { OAuth2BaseProvider, OAuth2BaseProviderConfig } from "./oauth2.base"; import { ucFirst } from "../helpers";
import { OAuth2BaseProvider, OAuth2BaseProviderConfig, OAuth2Tokens } from "./oauth2.base";
export interface OAuth2ProviderConfig extends OAuth2BaseProviderConfig { export interface OAuth2ProviderConfig<ProfileType = any, TokensType extends OAuth2Tokens = any>
extends OAuth2BaseProviderConfig<ProfileType, TokensType> {
accessTokenUrl?: string; accessTokenUrl?: string;
authorizationUrl?: string; authorizationUrl?: string;
profileUrl?: string; profileUrl?: string;
@ -14,31 +16,35 @@ export interface OAuth2ProviderConfig extends OAuth2BaseProviderConfig {
params: any; params: any;
grantType?: string; grantType?: string;
responseType?: string; responseType?: string;
contentType?: "application/json" | "application/x-www-form-urlencoded";
} }
const defaultConfig: Partial<OAuth2ProviderConfig> = { const defaultConfig: Partial<OAuth2ProviderConfig> = {
responseType: "code", responseType: "code",
grantType: "authorization_code", grantType: "authorization_code",
contentType: "application/json",
}; };
export class OAuth2Provider< export class OAuth2Provider<
T extends OAuth2ProviderConfig = OAuth2ProviderConfig, ProfileType = any,
> extends OAuth2BaseProvider<T> { TokensType extends OAuth2Tokens = OAuth2Tokens,
constructor(config: T) { ConfigType extends OAuth2ProviderConfig = OAuth2ProviderConfig<ProfileType, TokensType>,
> extends OAuth2BaseProvider<ProfileType, TokensType, ConfigType> {
constructor(config: ConfigType) {
super({ super({
...defaultConfig, ...defaultConfig,
...config, ...config,
}); });
} }
getAuthorizationUrl({ host }: ServerRequest, auth: Auth, state: string) { getAuthorizationUrl({ host }: ServerRequest, auth: Auth, state: string, nonce: string) {
const data = { const data = {
state, state,
nonce,
response_type: this.config.responseType, response_type: this.config.responseType,
client_id: this.config.clientId, client_id: this.config.clientId,
scope: Array.isArray(this.config.scope) ? this.config.scope.join(" ") : this.config.scope, scope: Array.isArray(this.config.scope) ? this.config.scope.join(" ") : this.config.scope,
redirect_uri: this.getCallbackUri(auth, host), redirect_uri: this.getCallbackUri(auth, host),
nonce: Math.round(Math.random() * 1000).toString(), // TODO: Generate random based on user values
...(this.config.authorizationParams ?? {}), ...(this.config.authorizationParams ?? {}),
}; };
@ -46,8 +52,8 @@ export class OAuth2Provider<
return url; return url;
} }
async getTokens(code: string, redirectUri: string) { async getTokens(code: string, redirectUri: string): Promise<TokensType> {
const data = { const data: Record<string, any> = {
code, code,
grant_type: this.config.grantType, grant_type: this.config.grantType,
client_id: this.config.clientId, client_id: this.config.clientId,
@ -56,13 +62,30 @@ export class OAuth2Provider<
...(this.config.params ?? {}), ...(this.config.params ?? {}),
}; };
const res = await fetch(`${this.config.accessTokenUrl}?${new URLSearchParams(data)}`); let body: string;
if (this.config.contentType === "application/x-www-form-urlencoded") {
body = Object.entries(data)
.map(([key, value]) => `${encodeURIComponent(key)}=${encodeURIComponent(value)}`)
.join("&");
} else {
body = JSON.stringify(data);
}
const res = await fetch(this.config.accessTokenUrl!, {
body,
method: "POST",
headers: {
"Content-Type": this.config.contentType,
...(this.config.headers ?? {}),
},
});
return await res.json(); return await res.json();
} }
async getUserProfile(tokens: any) { async getUserProfile(tokens: TokensType): Promise<ProfileType> {
const res = await fetch(this.config.profileUrl!, { const res = await fetch(this.config.profileUrl!, {
headers: { Authorization: `${tokens.token_type} ${tokens.access_token}` }, headers: { Authorization: `${ucFirst(tokens.token_type)} ${tokens.access_token}` },
}); });
return await res.json(); return await res.json();
} }

View File

@ -1,12 +1,165 @@
import { OAuth2Provider, OAuth2ProviderConfig } from "./oauth2"; import { OAuth2Provider, OAuth2ProviderConfig } from "./oauth2";
import type { ProfileCallback } from "./oauth2.base";
interface RedditOAuth2ProviderConfig extends OAuth2ProviderConfig { export interface RedditProfile {
is_employee: boolean;
seen_layout_switch: boolean;
has_visited_new_profile: boolean;
pref_no_profanity: boolean;
has_external_account: boolean;
pref_geopopular: string;
seen_redesign_modal: boolean;
pref_show_trending: boolean;
subreddit: Subreddit;
pref_show_presence: boolean;
snoovatar_img: string;
snoovatar_size: number[];
gold_expiration: number;
has_gold_subscription: boolean;
is_sponsor: boolean;
num_friends: number;
features: Features;
can_edit_name: boolean;
verified: boolean;
pref_autoplay: boolean;
coins: number;
has_paypal_subscription: boolean;
has_subscribed_to_premium: boolean;
id: string;
has_stripe_subscription: boolean;
oauth_client_id: string;
can_create_subreddit: boolean;
over_18: boolean;
is_gold: boolean;
is_mod: boolean;
awarder_karma: number;
suspension_expiration_utc: null;
has_verified_email: boolean;
is_suspended: boolean;
pref_video_autoplay: boolean;
has_android_subscription: boolean;
in_redesign_beta: boolean;
icon_img: string;
pref_nightmode: boolean;
awardee_karma: number;
hide_from_robots: boolean;
password_set: boolean;
link_karma: number;
force_password_reset: boolean;
total_karma: number;
seen_give_award_tooltip: boolean;
inbox_count: number;
seen_premium_adblock_modal: boolean;
pref_top_karma_subreddits: boolean;
pref_show_snoovatar: boolean;
name: string;
pref_clickgadget: number;
created: number;
gold_creddits: number;
created_utc: number;
has_ios_subscription: boolean;
pref_show_twitter: boolean;
in_beta: boolean;
comment_karma: number;
has_subscribed: boolean;
linked_identities: string[];
seen_subreddit_chat_ftux: boolean;
}
export interface Features {
mod_service_mute_writes: boolean;
promoted_trend_blanks: boolean;
show_amp_link: boolean;
chat: boolean;
is_email_permission_required: boolean;
mod_awards: boolean;
expensive_coins_package: boolean;
mweb_xpromo_revamp_v2: MwebXpromoRevampV;
awards_on_streams: boolean;
webhook_config: boolean;
mweb_xpromo_modal_listing_click_daily_dismissible_ios: boolean;
live_orangereds: boolean;
modlog_copyright_removal: boolean;
show_nps_survey: boolean;
do_not_track: boolean;
mod_service_mute_reads: boolean;
chat_user_settings: boolean;
use_pref_account_deployment: boolean;
mweb_xpromo_interstitial_comments_ios: boolean;
chat_subreddit: boolean;
noreferrer_to_noopener: boolean;
premium_subscriptions_table: boolean;
mweb_xpromo_interstitial_comments_android: boolean;
chat_group_rollout: boolean;
resized_styles_images: boolean;
spez_modal: boolean;
mweb_xpromo_modal_listing_click_daily_dismissible_android: boolean;
mweb_xpromo_revamp_v3: MwebXpromoRevampV;
}
export interface MwebXpromoRevampV {
owner: string;
variant: string;
experiment_id: number;
}
export interface Subreddit {
default_set: boolean;
user_is_contributor: boolean;
banner_img: string;
restrict_posting: boolean;
user_is_banned: boolean;
free_form_reports: boolean;
community_icon: null;
show_media: boolean;
icon_color: string;
user_is_muted: boolean;
display_name: string;
header_img: null;
title: string;
coins: number;
previous_names: any[];
over_18: boolean;
icon_size: number[];
primary_color: string;
icon_img: string;
description: string;
submit_link_label: string;
header_size: null;
restrict_commenting: boolean;
subscribers: number;
submit_text_label: string;
is_default_icon: boolean;
link_flair_position: string;
display_name_prefixed: string;
key_color: string;
name: string;
is_default_banner: boolean;
url: string;
quarantine: boolean;
banner_size: number[];
user_is_moderator: boolean;
public_description: string;
link_flair_enabled: boolean;
disable_contributor_requests: boolean;
subreddit_type: string;
user_is_subscriber: boolean;
}
export interface RedditTokens {
access_token: string;
token_type: string;
expires_in: number;
scope: string;
}
interface RedditOAuth2ProviderConfig extends OAuth2ProviderConfig<RedditProfile, RedditTokens> {
duration?: "temporary" | "permanent"; duration?: "temporary" | "permanent";
apiKey: string; apiKey: string;
apiSecret: string; apiSecret: string;
} }
const redditProfileHandler = ({ const redditProfileHandler: ProfileCallback<RedditProfile, RedditTokens> = ({
is_employee, is_employee,
has_external_account, has_external_account,
snoovatar_img, snoovatar_img,
@ -55,48 +208,35 @@ const redditProfileHandler = ({
const defaultConfig: Partial<RedditOAuth2ProviderConfig> = { const defaultConfig: Partial<RedditOAuth2ProviderConfig> = {
id: "reddit", id: "reddit",
scope: "identity", scope: "identity",
duration: "temporary",
profile: redditProfileHandler, profile: redditProfileHandler,
authorizationUrl: "https://www.reddit.com/api/v1/authorize", authorizationUrl: "https://www.reddit.com/api/v1/authorize",
accessTokenUrl: "https://www.reddit.com/api/v1/access_token", accessTokenUrl: "https://www.reddit.com/api/v1/access_token",
profileUrl: "https://oauth.reddit.com/api/v1/me", profileUrl: "https://oauth.reddit.com/api/v1/me",
contentType: "application/x-www-form-urlencoded",
}; };
export class RedditOAuth2Provider extends OAuth2Provider<RedditOAuth2ProviderConfig> { export class RedditOAuth2Provider extends OAuth2Provider<
RedditProfile,
RedditTokens,
RedditOAuth2ProviderConfig
> {
constructor(config: RedditOAuth2ProviderConfig) { constructor(config: RedditOAuth2ProviderConfig) {
super({ super({
...defaultConfig, ...defaultConfig,
...config, ...config,
clientId: config.apiKey, clientId: config.apiKey,
clientSecret: config.apiSecret, clientSecret: config.apiSecret,
headers: {
...config.headers,
Authorization:
"Basic " + Buffer.from(`${config.apiKey}:${config.apiSecret}`).toString("base64"),
},
authorizationParams: {
...config.authorizationParams,
duration: config.duration ?? "temporary",
},
}); });
} }
static profileHandler = redditProfileHandler; static profileHandler = redditProfileHandler;
async getTokens(code: string, redirectUri: string) {
const endpoint = this.config.accessTokenUrl!;
const data = {
code,
redirect_uri: redirectUri,
grant_type: this.config.grantType!,
};
const body = Object.entries(data)
.map(([key, value]) => `${encodeURIComponent(key)}=${encodeURIComponent(value)}`)
.join("&");
const res = await fetch(endpoint, {
method: "POST",
headers: {
"Content-Type": "application/x-www-form-urlencoded",
Authorization:
"Basic " +
Buffer.from(`${this.config.clientId}:${this.config.clientSecret}`).toString("base64"),
},
body,
});
return await res.json();
}
} }

View File

@ -1,4 +1,5 @@
import type { ServerRequest } from "@sveltejs/kit/types/endpoint"; import type { ServerRequest } from "@sveltejs/kit/types/endpoint";
import type { Auth } from "../auth";
import type { CallbackResult } from "../types"; import type { CallbackResult } from "../types";
import { OAuth2BaseProvider, OAuth2BaseProviderConfig } from "./oauth2.base"; import { OAuth2BaseProvider, OAuth2BaseProviderConfig } from "./oauth2.base";
@ -19,11 +20,11 @@ export class TwitterAuthProvider extends OAuth2BaseProvider<TwitterAuthProviderC
}); });
} }
async getRequestToken(host: string) { async getRequestToken(auth: Auth, host?: string) {
const endpoint = "https://api.twitter.com/oauth/request_token"; const endpoint = "https://api.twitter.com/oauth/request_token";
const data = { const data = {
oauth_callback: encodeURIComponent(this.getCallbackUri(host)), oauth_callback: encodeURIComponent(this.getCallbackUri(auth, host)),
oauth_consumer_key: this.config.apiKey, oauth_consumer_key: this.config.apiKey,
}; };
@ -37,10 +38,10 @@ export class TwitterAuthProvider extends OAuth2BaseProvider<TwitterAuthProviderC
}; };
} }
async getAuthorizationUrl({ host }: ServerRequest) { async getAuthorizationUrl({ host }: ServerRequest, auth: Auth, state: string, nonce: string) {
const endpoint = "https://api.twitter.com/oauth/authorize"; const endpoint = "https://api.twitter.com/oauth/authorize";
const { oauthToken } = await this.getRequestToken(host); const { oauthToken } = await this.getRequestToken(auth, host);
const data = { const data = {
oauth_token: oauthToken, oauth_token: oauthToken,
@ -70,7 +71,7 @@ export class TwitterAuthProvider extends OAuth2BaseProvider<TwitterAuthProviderC
return await res.json(); return await res.json();
} }
async callback({ query, host }: ServerRequest): Promise<CallbackResult> { async callback({ query, host }: ServerRequest, auth: Auth): Promise<CallbackResult> {
const oauthToken = query.get("oauth_token"); const oauthToken = query.get("oauth_token");
const oauthVerifier = query.get("oauth_verifier"); const oauthVerifier = query.get("oauth_verifier");
const redirect = this.getStateValue(query, "redirect"); const redirect = this.getStateValue(query, "redirect");
@ -82,6 +83,6 @@ export class TwitterAuthProvider extends OAuth2BaseProvider<TwitterAuthProviderC
user = await this.config.profile(user, tokens); user = await this.config.profile(user, tokens);
} }
return [user, redirect ?? this.getUri(host, "/")]; return [user, redirect ?? this.getUri(auth, "/", host)];
} }
} }