[ENHANCEMENT] OAuth Base Provider (#12)

*  Inject auth instance into provider `signin()` and `callback()` methods

Add generic OAuth provider to implement with simple config.

* 🐛 Fix storing multiple social connections in demo app

*  Create `apiKey` and `apiSecret` aliases for Reddit provider

* ⬆️ Reinstall local dep

* 🏷️ Remove comments / use `OAuth2ProviderConfig` for `GoogleOAuth2Provider` types
This commit is contained in:
Dan6erbond 2021-05-23 22:09:57 +02:00 committed by GitHub
parent 6397de8a45
commit b4f7688377
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 189 additions and 232 deletions

View File

@ -46,7 +46,7 @@ export const appAuth = new SvelteKitAuth({
...token, ...token,
user: { user: {
...token.user, ...token.user,
connections: { [provider]: account }, connections: { ...token.user.connections, [provider]: account },
}, },
}; };
} }

View File

@ -2095,6 +2095,12 @@ simple-swizzle@^0.2.2:
dependencies: dependencies:
is-arrayish "^0.3.1" is-arrayish "^0.3.1"
"sk-auth@file:..":
version "0.1.1"
dependencies:
cookie "^0.4.1"
jsonwebtoken "^8.5.1"
"sk-auth@file:../": "sk-auth@file:../":
version "0.1.1" version "0.1.1"
dependencies: dependencies:

View File

@ -105,7 +105,7 @@ export class Auth {
provider: Provider, provider: Provider,
): Promise<EndpointOutput> { ): Promise<EndpointOutput> {
const { headers, host } = request; const { headers, host } = request;
const [profile, redirectUrl] = await provider.callback(request); const [profile, redirectUrl] = await provider.callback(request, this);
let token = (await this.getToken(headers)) ?? { user: {} }; let token = (await this.getToken(headers)) ?? { user: {} };
if (this.config?.callbacks?.jwt) { if (this.config?.callbacks?.jwt) {
@ -129,7 +129,7 @@ export class Auth {
async handleEndpoint(request: ServerRequest): Promise<EndpointOutput> { async handleEndpoint(request: ServerRequest): Promise<EndpointOutput> {
const { path, headers, method, host } = request; const { path, headers, method, host } = request;
if (path === this.getPath("signout")) { if (path === this.getPath("signout", host)) {
const token = this.setToken(headers, {}); const token = this.setToken(headers, {});
const jwt = this.signToken(token); const jwt = this.signToken(token);
@ -163,7 +163,7 @@ export class Auth {
); );
if (provider) { if (provider) {
if (match.groups.method === "signin") { if (match.groups.method === "signin") {
return await provider.signin(request); return await provider.signin(request, this);
} else { } else {
return await this.handleProviderCallback(request, provider); return await this.handleProviderCallback(request, provider);
} }

View File

@ -1,5 +1,6 @@
import type { EndpointOutput } from "@sveltejs/kit"; import type { EndpointOutput } from "@sveltejs/kit";
import type { ServerRequest } from "@sveltejs/kit/types/endpoint"; import type { ServerRequest } from "@sveltejs/kit/types/endpoint";
import type { Auth } from "../auth";
import type { CallbackResult } from "../types"; import type { CallbackResult } from "../types";
export interface ProviderConfig { export interface ProviderConfig {
@ -14,19 +15,25 @@ export abstract class Provider<T extends ProviderConfig = ProviderConfig> {
this.id = config.id!; this.id = config.id!;
} }
getUri(host: string, path: string) { getUri(svelteKitAuth: Auth, path: string, host?: string) {
return `http://${host}${path}`; return `http://${host}${path}`;
} }
getCallbackUri(host: string) { getCallbackUri(svelteKitAuth: Auth, host?: string) {
return this.getUri(host, `${"/api/auth/callback/"}${this.id}`); return svelteKitAuth.getPath(`${"/api/auth/callback/"}${this.id}`, host);
}
getSigninUri(svelteKitAuth: Auth, host?: string) {
return svelteKitAuth.getPath(`${"/api/auth/signin/"}${this.id}`, host);
} }
abstract signin<Locals extends Record<string, any> = Record<string, any>, Body = unknown>( abstract signin<Locals extends Record<string, any> = Record<string, any>, Body = unknown>(
request: ServerRequest<Locals, Body>, request: ServerRequest<Locals, Body>,
svelteKitAuth: Auth,
): EndpointOutput | Promise<EndpointOutput>; ): EndpointOutput | Promise<EndpointOutput>;
abstract callback<Locals extends Record<string, any> = Record<string, any>, Body = unknown>( abstract callback<Locals extends Record<string, any> = Record<string, any>, Body = unknown>(
request: ServerRequest<Locals, Body>, request: ServerRequest<Locals, Body>,
svelteKitAuth: Auth,
): CallbackResult | Promise<CallbackResult>; ): CallbackResult | Promise<CallbackResult>;
} }

View File

@ -1,78 +1,37 @@
import type { ServerRequest } from "@sveltejs/kit/types/endpoint";
import { OAuth2Provider, OAuth2ProviderConfig } from "./oauth2"; import { OAuth2Provider, OAuth2ProviderConfig } from "./oauth2";
interface FacebookAuthProviderConfig extends OAuth2ProviderConfig { interface FacebookOAuth2ProviderConfig extends OAuth2ProviderConfig {
clientId: string; userProfileFields?: string | string[];
clientSecret: string;
userProfileFields?: string;
scope?: string;
} }
const defaultConfig: Partial<FacebookAuthProviderConfig> = { const defaultConfig: Partial<FacebookOAuth2ProviderConfig> = {
id: "facebook", id: "facebook",
scope: "email public_profile user_link", scope: ["email", "public_profile", "user_link"],
userProfileFields: userProfileFields: [
"id,name,first_name,last_name,middle_name,name_format,picture,short_name,email", "id",
"name",
"first_name",
"last_name",
"middle_name",
"name_format",
"picture",
"short_name",
"email",
],
profileUrl: "https://graph.facebook.com/me",
}; };
export class FacebookAuthProvider extends OAuth2Provider<FacebookAuthProviderConfig> { export class FacebookOAuth2Provider extends OAuth2Provider<FacebookOAuth2ProviderConfig> {
constructor(config: FacebookAuthProviderConfig) { constructor(config: FacebookOAuth2ProviderConfig) {
const userProfileFields = config.userProfileFields || defaultConfig.userProfileFields;
const profileUrl = `${config.profileUrl || defaultConfig.profileUrl}?${
Array.isArray(userProfileFields) ? userProfileFields.join(",") : userProfileFields
}`;
super({ super({
...defaultConfig, ...defaultConfig,
profileUrl,
...config, ...config,
}); });
} }
getSigninUrl({ host }: ServerRequest, state: string) {
const endpoint = "https://www.facebook.com/v10.0/dialog/oauth";
const data = {
client_id: this.config.clientId,
scope: this.config.scope!,
redirect_uri: this.getCallbackUri(host),
state,
};
const url = `${endpoint}?${new URLSearchParams(data)}`;
return url;
}
async getTokens(code: string, redirectUri: string) {
const endpoint = "https://graph.facebook.com/v10.0/oauth/access_token";
const data = {
code,
client_id: this.config.clientId,
redirect_uri: redirectUri,
client_secret: this.config.clientSecret,
};
const res = await fetch(`${endpoint}?${new URLSearchParams(data)}`);
return await res.json();
}
async inspectToken(tokens: any) {
const endpoint = "https://graph.facebook.com/debug_token";
const data = {
input_token: tokens.access_token,
access_token: `${this.config.clientId}|${this.config.clientSecret}`,
};
const res = await fetch(`${endpoint}?${new URLSearchParams(data)}`);
return await res.json();
}
async getUserProfile(tokens: any) {
const inspectResult = await this.inspectToken(tokens);
const endpoint = `https://graph.facebook.com/v10.0/${inspectResult.data.user_id}`;
const data = {
access_token: tokens.access_token,
fields: this.config.userProfileFields!,
};
const res = await fetch(`${endpoint}?${new URLSearchParams(data)}`);
return await res.json();
}
} }

View File

@ -1,82 +1,18 @@
import type { ServerRequest } from "@sveltejs/kit/types/endpoint";
import { OAuth2Provider, OAuth2ProviderConfig } from "./oauth2"; import { OAuth2Provider, OAuth2ProviderConfig } from "./oauth2";
interface GoogleOAuthProviderConfig extends OAuth2ProviderConfig { const defaultConfig: Partial<OAuth2ProviderConfig> = {
clientId: string;
clientSecret: string;
discoveryDocument?: string;
scope?: string;
}
const defaultConfig: Partial<GoogleOAuthProviderConfig> = {
id: "google", id: "google",
discoveryDocument: "https://accounts.google.com/.well-known/openid-configuration", scope: ["openid", "profile", "email"],
scope: "openid profile email", accessTokenUrl: "https://accounts.google.com/o/oauth2/token",
authorizationUrl: "https://accounts.google.com/o/oauth2/auth?response_type=code",
profileUrl: "https://openidconnect.googleapis.com/v1/userinfo",
}; };
export class GoogleOAuthProvider extends OAuth2Provider<GoogleOAuthProviderConfig> { export class GoogleOAuth2Provider extends OAuth2Provider {
constructor(config: GoogleOAuthProviderConfig) { constructor(config: OAuth2ProviderConfig) {
super({ super({
...defaultConfig, ...defaultConfig,
...config, ...config,
}); });
} }
async getProviderMetadata() {
const res = await fetch(this.config.discoveryDocument!);
const metadata = await res.json();
return metadata;
}
async getEndpoint(key: string) {
const metadata = await this.getProviderMetadata();
return metadata[key] as string;
}
async getSigninUrl({ host }: ServerRequest, state: string) {
const authorizationEndpoint = await this.getEndpoint("authorization_endpoint");
const data = {
response_type: "code",
client_id: this.config.clientId,
scope: this.config.scope!,
redirect_uri: this.getCallbackUri(host),
state,
login_hint: "example@provider.com",
nonce: Math.round(Math.random() * 1000).toString(), // TODO: Generate random based on user values
};
const url = `${authorizationEndpoint}?${new URLSearchParams(data)}`;
return url;
}
async getTokens(code: string, redirectUri: string) {
const tokenEndpoint = await this.getEndpoint("token_endpoint");
const data = {
code,
client_id: this.config.clientId,
client_secret: this.config.clientSecret,
redirect_uri: redirectUri,
grant_type: "authorization_code",
};
const res = await fetch(tokenEndpoint, {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify(data),
});
return await res.json();
}
async getUserProfile(tokens: any) {
const userProfileEndpoint = await this.getEndpoint("userinfo_endpoint");
const res = await fetch(userProfileEndpoint, {
headers: { Authorization: `${tokens.token_type} ${tokens.access_token}` },
});
return await res.json();
}
} }

View File

@ -1,6 +1,6 @@
export { Provider } from "./base"; export { Provider } from "./base";
export { GoogleOAuthProvider } from "./google"; export { GoogleOAuth2Provider as GoogleOAuthProvider } from "./google";
export { TwitterAuthProvider } from "./twitter"; export { TwitterAuthProvider } from "./twitter";
export { FacebookAuthProvider } from "./facebook"; export { FacebookOAuth2Provider as FacebookAuthProvider } from "./facebook";
export { OAuth2Provider } from "./oauth2"; export { OAuth2BaseProvider as OAuth2Provider } from "./oauth2.base";
export { RedditOAuthProvider } from "./reddit"; export { RedditOAuth2Provider as RedditOAuthProvider } from "./reddit";

View File

@ -0,0 +1,60 @@
import type { EndpointOutput, ServerRequest } from "@sveltejs/kit/types/endpoint";
import type { Auth } from "../auth";
import type { CallbackResult } from "../types";
import { Provider, ProviderConfig } from "./base";
export interface OAuth2BaseProviderConfig extends ProviderConfig {
profile?: (profile: any, tokens: any) => any | Promise<any>;
}
export abstract class OAuth2BaseProvider<T extends OAuth2BaseProviderConfig> extends Provider<T> {
abstract getAuthorizationUrl(request: ServerRequest, auth: Auth, state: string): string | Promise<string>;
abstract getTokens(code: string, redirectUri: string): any | Promise<any>;
abstract getUserProfile(tokens: any): any | Promise<any>;
async signin(request: ServerRequest, auth: Auth): Promise<EndpointOutput> {
const { method, host, query } = request;
const state = [`redirect=${query.get("redirect") ?? this.getUri(auth, host, "/")}`].join(",");
const base64State = Buffer.from(state).toString("base64");
const url = await this.getAuthorizationUrl(request, auth, base64State);
if (method === "POST") {
return {
body: {
redirect: url,
},
};
}
return {
status: 302,
headers: {
Location: url,
},
};
}
getStateValue(query: URLSearchParams, name: string) {
if (query.get("state")) {
const state = Buffer.from(query.get("state")!, "base64").toString();
return state
.split(",")
.find((state) => state.startsWith(`${name}=`))
?.replace(`${name}=`, "");
}
}
async callback({ query, host }: ServerRequest, auth: Auth): Promise<CallbackResult> {
const code = query.get("code");
const redirect = this.getStateValue(query, "redirect");
const tokens = await this.getTokens(code!, this.getCallbackUri(auth, host));
let user = await this.getUserProfile(tokens);
if (this.config.profile) {
user = await this.config.profile(user, tokens);
}
return [user, redirect ?? this.getUri(auth, host, "/")];
}
}

View File

@ -1,59 +1,69 @@
import type { EndpointOutput, ServerRequest } from "@sveltejs/kit/types/endpoint"; import type { ServerRequest } from "@sveltejs/kit/types/endpoint";
import type { CallbackResult } from "../types"; import type { Auth } from "../auth";
import { Provider, ProviderConfig } from "./base"; import { OAuth2BaseProvider, OAuth2BaseProviderConfig } from "./oauth2.base";
export interface OAuth2ProviderConfig extends ProviderConfig { export interface OAuth2ProviderConfig extends OAuth2BaseProviderConfig {
profile?: (profile: any, tokens: any) => any | Promise<any>; accessTokenUrl?: string;
authorizationUrl?: string;
profileUrl?: string;
clientId?: string;
clientSecret?: string;
scope: string | string[];
headers?: any;
authorizationParams?: any;
params: any;
grantType?: string;
responseType?: string;
} }
export abstract class OAuth2Provider<T extends OAuth2ProviderConfig> extends Provider<T> { const defaultConfig: Partial<OAuth2ProviderConfig> = {
abstract getSigninUrl(request: ServerRequest, state: string): string | Promise<string>; responseType: "code",
abstract getTokens(code: string, redirectUri: string): any | Promise<any>; grantType: "authorization_code",
abstract getUserProfile(tokens: any): any | Promise<any>; };
async signin(request: ServerRequest): Promise<EndpointOutput> { export class OAuth2Provider<
const { method, host, query } = request; T extends OAuth2ProviderConfig = OAuth2ProviderConfig,
const state = [`redirect=${query.get("redirect") ?? this.getUri(host, "/")}`].join(","); > extends OAuth2BaseProvider<T> {
const base64State = Buffer.from(state).toString("base64"); constructor(config: T) {
const url = await this.getSigninUrl(request, base64State); super({
...defaultConfig,
...config,
});
}
if (method === "POST") { getAuthorizationUrl({ host }: ServerRequest, auth: Auth, state: string) {
return { const data = {
body: { state,
redirect: url, response_type: this.config.responseType,
}, client_id: this.config.clientId,
}; scope: Array.isArray(this.config.scope) ? this.config.scope.join(" ") : this.config.scope,
} redirect_uri: this.getCallbackUri(auth, host),
nonce: Math.round(Math.random() * 1000).toString(), // TODO: Generate random based on user values
return { ...(this.config.authorizationParams ?? {}),
status: 302,
headers: {
Location: url,
},
}; };
const url = `${this.config.authorizationUrl}?${new URLSearchParams(data)}`;
return url;
} }
getStateValue(query: URLSearchParams, name: string) { async getTokens(code: string, redirectUri: string) {
if (query.get("state")) { const data = {
const state = Buffer.from(query.get("state")!, "base64").toString(); code,
return state grant_type: this.config.grantType,
.split(",") client_id: this.config.clientId,
.find((state) => state.startsWith(`${name}=`)) redirect_uri: redirectUri,
?.replace(`${name}=`, ""); client_secret: this.config.clientSecret,
} ...(this.config.params ?? {}),
};
const res = await fetch(`${this.config.accessTokenUrl}?${new URLSearchParams(data)}`);
return await res.json();
} }
async callback({ query, host }: ServerRequest): Promise<CallbackResult> { async getUserProfile(tokens: any) {
const code = query.get("code"); const res = await fetch(this.config.profileUrl!, {
const redirect = this.getStateValue(query, "redirect"); headers: { Authorization: `${tokens.token_type} ${tokens.access_token}` },
});
const tokens = await this.getTokens(code!, this.getCallbackUri(host)); return await res.json();
let user = await this.getUserProfile(tokens);
if (this.config.profile) {
user = await this.config.profile(user, tokens);
}
return [user, redirect ?? this.getUri(host, "/")];
} }
} }

View File

@ -1,11 +1,9 @@
import type { ServerRequest } from "@sveltejs/kit/types/endpoint";
import { OAuth2Provider, OAuth2ProviderConfig } from "./oauth2"; import { OAuth2Provider, OAuth2ProviderConfig } from "./oauth2";
interface RedditOAuthProviderConfig extends OAuth2ProviderConfig { interface RedditOAuth2ProviderConfig extends OAuth2ProviderConfig {
duration?: "temporary" | "permanent";
apiKey: string; apiKey: string;
apiSecret: string; apiSecret: string;
scope?: string;
duration?: "temporary" | "permanent";
} }
const redditProfileHandler = ({ const redditProfileHandler = ({
@ -54,46 +52,35 @@ const redditProfileHandler = ({
comment_karma, comment_karma,
}); });
const defaultConfig: Partial<RedditOAuthProviderConfig> = { const defaultConfig: Partial<RedditOAuth2ProviderConfig> = {
id: "reddit", id: "reddit",
scope: "identity", scope: "identity",
duration: "temporary", duration: "temporary",
profile: redditProfileHandler, profile: redditProfileHandler,
authorizationUrl: "https://www.reddit.com/api/v1/authorize",
accessTokenUrl: "https://www.reddit.com/api/v1/access_token",
profileUrl: "https://oauth.reddit.com/api/v1/me",
}; };
export class RedditOAuthProvider extends OAuth2Provider<RedditOAuthProviderConfig> { export class RedditOAuth2Provider extends OAuth2Provider<RedditOAuth2ProviderConfig> {
constructor(config: RedditOAuthProviderConfig) { constructor(config: RedditOAuth2ProviderConfig) {
super({ super({
...defaultConfig, ...defaultConfig,
...config, ...config,
clientId: config.apiKey,
clientSecret: config.apiSecret,
}); });
} }
static profileHandler = redditProfileHandler; static profileHandler = redditProfileHandler;
async getSigninUrl({ host }: ServerRequest, state: string) {
const endpoint = "https://www.reddit.com/api/v1/authorize";
const data = {
client_id: this.config.apiKey,
response_type: "code",
state,
redirect_uri: this.getCallbackUri(host),
duration: this.config.duration!,
scope: this.config.scope!,
};
const url = `${endpoint}?${new URLSearchParams(data)}`;
return url;
}
async getTokens(code: string, redirectUri: string) { async getTokens(code: string, redirectUri: string) {
const endpoint = "https://www.reddit.com/api/v1/access_token"; const endpoint = this.config.accessTokenUrl!;
const data = { const data = {
code, code,
redirect_uri: redirectUri, redirect_uri: redirectUri,
grant_type: "authorization_code", grant_type: this.config.grantType!,
}; };
const body = Object.entries(data) const body = Object.entries(data)
.map(([key, value]) => `${encodeURIComponent(key)}=${encodeURIComponent(value)}`) .map(([key, value]) => `${encodeURIComponent(key)}=${encodeURIComponent(value)}`)
@ -105,19 +92,11 @@ export class RedditOAuthProvider extends OAuth2Provider<RedditOAuthProviderConfi
"Content-Type": "application/x-www-form-urlencoded", "Content-Type": "application/x-www-form-urlencoded",
Authorization: Authorization:
"Basic " + "Basic " +
Buffer.from(`${this.config.apiKey}:${this.config.apiSecret}`).toString("base64"), Buffer.from(`${this.config.clientId}:${this.config.clientSecret}`).toString("base64"),
}, },
body, body,
}); });
return await res.json(); return await res.json();
} }
async getUserProfile(tokens: any) {
const endpoint = "https://oauth.reddit.com/api/v1/me";
const res = await fetch(endpoint, {
headers: { Authorization: `${tokens.token_type} ${tokens.access_token}` },
});
return await res.json();
}
} }

View File

@ -1,8 +1,8 @@
import type { ServerRequest } from "@sveltejs/kit/types/endpoint"; import type { ServerRequest } from "@sveltejs/kit/types/endpoint";
import type { CallbackResult } from "../types"; import type { CallbackResult } from "../types";
import { OAuth2Provider, OAuth2ProviderConfig } from "./oauth2"; import { OAuth2BaseProvider, OAuth2BaseProviderConfig } from "./oauth2.base";
interface TwitterAuthProviderConfig extends OAuth2ProviderConfig { interface TwitterAuthProviderConfig extends OAuth2BaseProviderConfig {
apiKey: string; apiKey: string;
apiSecret: string; apiSecret: string;
} }
@ -11,7 +11,7 @@ const defaultConfig: Partial<TwitterAuthProviderConfig> = {
id: "twitter", id: "twitter",
}; };
export class TwitterAuthProvider extends OAuth2Provider<TwitterAuthProviderConfig> { export class TwitterAuthProvider extends OAuth2BaseProvider<TwitterAuthProviderConfig> {
constructor(config: TwitterAuthProviderConfig) { constructor(config: TwitterAuthProviderConfig) {
super({ super({
...defaultConfig, ...defaultConfig,
@ -37,7 +37,7 @@ export class TwitterAuthProvider extends OAuth2Provider<TwitterAuthProviderConfi
}; };
} }
async getSigninUrl({ host }: ServerRequest) { async getAuthorizationUrl({ host }: ServerRequest) {
const endpoint = "https://api.twitter.com/oauth/authorize"; const endpoint = "https://api.twitter.com/oauth/authorize";
const { oauthToken } = await this.getRequestToken(host); const { oauthToken } = await this.getRequestToken(host);