diff --git a/app/src/lib/appAuth.ts b/app/src/lib/appAuth.ts index 046107c..7c8b1a1 100644 --- a/app/src/lib/appAuth.ts +++ b/app/src/lib/appAuth.ts @@ -46,7 +46,7 @@ export const appAuth = new SvelteKitAuth({ ...token, user: { ...token.user, - connections: { [provider]: account }, + connections: { ...token.user.connections, [provider]: account }, }, }; } diff --git a/app/yarn.lock b/app/yarn.lock index 6168086..ad83970 100644 --- a/app/yarn.lock +++ b/app/yarn.lock @@ -2095,6 +2095,12 @@ simple-swizzle@^0.2.2: dependencies: is-arrayish "^0.3.1" +"sk-auth@file:..": + version "0.1.1" + dependencies: + cookie "^0.4.1" + jsonwebtoken "^8.5.1" + "sk-auth@file:../": version "0.1.1" dependencies: diff --git a/src/auth.ts b/src/auth.ts index edecf04..f7250c1 100644 --- a/src/auth.ts +++ b/src/auth.ts @@ -105,7 +105,7 @@ export class Auth { provider: Provider, ): Promise { const { headers, host } = request; - const [profile, redirectUrl] = await provider.callback(request); + const [profile, redirectUrl] = await provider.callback(request, this); let token = (await this.getToken(headers)) ?? { user: {} }; if (this.config?.callbacks?.jwt) { @@ -129,7 +129,7 @@ export class Auth { async handleEndpoint(request: ServerRequest): Promise { const { path, headers, method, host } = request; - if (path === this.getPath("signout")) { + if (path === this.getPath("signout", host)) { const token = this.setToken(headers, {}); const jwt = this.signToken(token); @@ -163,7 +163,7 @@ export class Auth { ); if (provider) { if (match.groups.method === "signin") { - return await provider.signin(request); + return await provider.signin(request, this); } else { return await this.handleProviderCallback(request, provider); } diff --git a/src/providers/base.ts b/src/providers/base.ts index 7671264..aef5bb4 100644 --- a/src/providers/base.ts +++ b/src/providers/base.ts @@ -1,5 +1,6 @@ import type { EndpointOutput } from "@sveltejs/kit"; import type { ServerRequest } from "@sveltejs/kit/types/endpoint"; +import type { Auth } from "../auth"; import type { CallbackResult } from "../types"; export interface ProviderConfig { @@ -14,19 +15,25 @@ export abstract class Provider { this.id = config.id!; } - getUri(host: string, path: string) { + getUri(svelteKitAuth: Auth, path: string, host?: string) { return `http://${host}${path}`; } - getCallbackUri(host: string) { - return this.getUri(host, `${"/api/auth/callback/"}${this.id}`); + getCallbackUri(svelteKitAuth: Auth, host?: string) { + return svelteKitAuth.getPath(`${"/api/auth/callback/"}${this.id}`, host); + } + + getSigninUri(svelteKitAuth: Auth, host?: string) { + return svelteKitAuth.getPath(`${"/api/auth/signin/"}${this.id}`, host); } abstract signin = Record, Body = unknown>( request: ServerRequest, + svelteKitAuth: Auth, ): EndpointOutput | Promise; abstract callback = Record, Body = unknown>( request: ServerRequest, + svelteKitAuth: Auth, ): CallbackResult | Promise; } diff --git a/src/providers/facebook.ts b/src/providers/facebook.ts index 594d819..8463937 100644 --- a/src/providers/facebook.ts +++ b/src/providers/facebook.ts @@ -1,78 +1,37 @@ -import type { ServerRequest } from "@sveltejs/kit/types/endpoint"; import { OAuth2Provider, OAuth2ProviderConfig } from "./oauth2"; -interface FacebookAuthProviderConfig extends OAuth2ProviderConfig { - clientId: string; - clientSecret: string; - userProfileFields?: string; - scope?: string; +interface FacebookOAuth2ProviderConfig extends OAuth2ProviderConfig { + userProfileFields?: string | string[]; } -const defaultConfig: Partial = { +const defaultConfig: Partial = { id: "facebook", - scope: "email public_profile user_link", - userProfileFields: - "id,name,first_name,last_name,middle_name,name_format,picture,short_name,email", + scope: ["email", "public_profile", "user_link"], + userProfileFields: [ + "id", + "name", + "first_name", + "last_name", + "middle_name", + "name_format", + "picture", + "short_name", + "email", + ], + profileUrl: "https://graph.facebook.com/me", }; -export class FacebookAuthProvider extends OAuth2Provider { - constructor(config: FacebookAuthProviderConfig) { +export class FacebookOAuth2Provider extends OAuth2Provider { + constructor(config: FacebookOAuth2ProviderConfig) { + const userProfileFields = config.userProfileFields || defaultConfig.userProfileFields; + const profileUrl = `${config.profileUrl || defaultConfig.profileUrl}?${ + Array.isArray(userProfileFields) ? userProfileFields.join(",") : userProfileFields + }`; + super({ ...defaultConfig, + profileUrl, ...config, }); } - - getSigninUrl({ host }: ServerRequest, state: string) { - const endpoint = "https://www.facebook.com/v10.0/dialog/oauth"; - - const data = { - client_id: this.config.clientId, - scope: this.config.scope!, - redirect_uri: this.getCallbackUri(host), - state, - }; - - const url = `${endpoint}?${new URLSearchParams(data)}`; - return url; - } - - async getTokens(code: string, redirectUri: string) { - const endpoint = "https://graph.facebook.com/v10.0/oauth/access_token"; - - const data = { - code, - client_id: this.config.clientId, - redirect_uri: redirectUri, - client_secret: this.config.clientSecret, - }; - - const res = await fetch(`${endpoint}?${new URLSearchParams(data)}`); - return await res.json(); - } - - async inspectToken(tokens: any) { - const endpoint = "https://graph.facebook.com/debug_token"; - - const data = { - input_token: tokens.access_token, - access_token: `${this.config.clientId}|${this.config.clientSecret}`, - }; - - const res = await fetch(`${endpoint}?${new URLSearchParams(data)}`); - return await res.json(); - } - - async getUserProfile(tokens: any) { - const inspectResult = await this.inspectToken(tokens); - const endpoint = `https://graph.facebook.com/v10.0/${inspectResult.data.user_id}`; - - const data = { - access_token: tokens.access_token, - fields: this.config.userProfileFields!, - }; - - const res = await fetch(`${endpoint}?${new URLSearchParams(data)}`); - return await res.json(); - } } diff --git a/src/providers/google.ts b/src/providers/google.ts index bb0bfd8..67deced 100644 --- a/src/providers/google.ts +++ b/src/providers/google.ts @@ -1,82 +1,18 @@ -import type { ServerRequest } from "@sveltejs/kit/types/endpoint"; import { OAuth2Provider, OAuth2ProviderConfig } from "./oauth2"; -interface GoogleOAuthProviderConfig extends OAuth2ProviderConfig { - clientId: string; - clientSecret: string; - discoveryDocument?: string; - scope?: string; -} - -const defaultConfig: Partial = { +const defaultConfig: Partial = { id: "google", - discoveryDocument: "https://accounts.google.com/.well-known/openid-configuration", - scope: "openid profile email", + scope: ["openid", "profile", "email"], + accessTokenUrl: "https://accounts.google.com/o/oauth2/token", + authorizationUrl: "https://accounts.google.com/o/oauth2/auth?response_type=code", + profileUrl: "https://openidconnect.googleapis.com/v1/userinfo", }; -export class GoogleOAuthProvider extends OAuth2Provider { - constructor(config: GoogleOAuthProviderConfig) { +export class GoogleOAuth2Provider extends OAuth2Provider { + constructor(config: OAuth2ProviderConfig) { super({ ...defaultConfig, ...config, }); } - - async getProviderMetadata() { - const res = await fetch(this.config.discoveryDocument!); - const metadata = await res.json(); - return metadata; - } - - async getEndpoint(key: string) { - const metadata = await this.getProviderMetadata(); - return metadata[key] as string; - } - - async getSigninUrl({ host }: ServerRequest, state: string) { - const authorizationEndpoint = await this.getEndpoint("authorization_endpoint"); - - const data = { - response_type: "code", - client_id: this.config.clientId, - scope: this.config.scope!, - redirect_uri: this.getCallbackUri(host), - state, - login_hint: "example@provider.com", - nonce: Math.round(Math.random() * 1000).toString(), // TODO: Generate random based on user values - }; - - const url = `${authorizationEndpoint}?${new URLSearchParams(data)}`; - return url; - } - - async getTokens(code: string, redirectUri: string) { - const tokenEndpoint = await this.getEndpoint("token_endpoint"); - - const data = { - code, - client_id: this.config.clientId, - client_secret: this.config.clientSecret, - redirect_uri: redirectUri, - grant_type: "authorization_code", - }; - - const res = await fetch(tokenEndpoint, { - method: "POST", - headers: { - "Content-Type": "application/json", - }, - body: JSON.stringify(data), - }); - - return await res.json(); - } - - async getUserProfile(tokens: any) { - const userProfileEndpoint = await this.getEndpoint("userinfo_endpoint"); - const res = await fetch(userProfileEndpoint, { - headers: { Authorization: `${tokens.token_type} ${tokens.access_token}` }, - }); - return await res.json(); - } } diff --git a/src/providers/index.ts b/src/providers/index.ts index 9516d40..38d23a7 100644 --- a/src/providers/index.ts +++ b/src/providers/index.ts @@ -1,6 +1,6 @@ export { Provider } from "./base"; -export { GoogleOAuthProvider } from "./google"; +export { GoogleOAuth2Provider as GoogleOAuthProvider } from "./google"; export { TwitterAuthProvider } from "./twitter"; -export { FacebookAuthProvider } from "./facebook"; -export { OAuth2Provider } from "./oauth2"; -export { RedditOAuthProvider } from "./reddit"; +export { FacebookOAuth2Provider as FacebookAuthProvider } from "./facebook"; +export { OAuth2BaseProvider as OAuth2Provider } from "./oauth2.base"; +export { RedditOAuth2Provider as RedditOAuthProvider } from "./reddit"; diff --git a/src/providers/oauth2.base.ts b/src/providers/oauth2.base.ts new file mode 100644 index 0000000..ff12fcd --- /dev/null +++ b/src/providers/oauth2.base.ts @@ -0,0 +1,60 @@ +import type { EndpointOutput, ServerRequest } from "@sveltejs/kit/types/endpoint"; +import type { Auth } from "../auth"; +import type { CallbackResult } from "../types"; +import { Provider, ProviderConfig } from "./base"; + +export interface OAuth2BaseProviderConfig extends ProviderConfig { + profile?: (profile: any, tokens: any) => any | Promise; +} + +export abstract class OAuth2BaseProvider extends Provider { + abstract getAuthorizationUrl(request: ServerRequest, auth: Auth, state: string): string | Promise; + abstract getTokens(code: string, redirectUri: string): any | Promise; + abstract getUserProfile(tokens: any): any | Promise; + + async signin(request: ServerRequest, auth: Auth): Promise { + const { method, host, query } = request; + const state = [`redirect=${query.get("redirect") ?? this.getUri(auth, host, "/")}`].join(","); + const base64State = Buffer.from(state).toString("base64"); + const url = await this.getAuthorizationUrl(request, auth, base64State); + + if (method === "POST") { + return { + body: { + redirect: url, + }, + }; + } + + return { + status: 302, + headers: { + Location: url, + }, + }; + } + + getStateValue(query: URLSearchParams, name: string) { + if (query.get("state")) { + const state = Buffer.from(query.get("state")!, "base64").toString(); + return state + .split(",") + .find((state) => state.startsWith(`${name}=`)) + ?.replace(`${name}=`, ""); + } + } + + async callback({ query, host }: ServerRequest, auth: Auth): Promise { + const code = query.get("code"); + const redirect = this.getStateValue(query, "redirect"); + + const tokens = await this.getTokens(code!, this.getCallbackUri(auth, host)); + let user = await this.getUserProfile(tokens); + + if (this.config.profile) { + user = await this.config.profile(user, tokens); + } + + return [user, redirect ?? this.getUri(auth, host, "/")]; + } +} diff --git a/src/providers/oauth2.ts b/src/providers/oauth2.ts index 463ef5c..60c948c 100644 --- a/src/providers/oauth2.ts +++ b/src/providers/oauth2.ts @@ -1,59 +1,69 @@ -import type { EndpointOutput, ServerRequest } from "@sveltejs/kit/types/endpoint"; -import type { CallbackResult } from "../types"; -import { Provider, ProviderConfig } from "./base"; +import type { ServerRequest } from "@sveltejs/kit/types/endpoint"; +import type { Auth } from "../auth"; +import { OAuth2BaseProvider, OAuth2BaseProviderConfig } from "./oauth2.base"; -export interface OAuth2ProviderConfig extends ProviderConfig { - profile?: (profile: any, tokens: any) => any | Promise; +export interface OAuth2ProviderConfig extends OAuth2BaseProviderConfig { + accessTokenUrl?: string; + authorizationUrl?: string; + profileUrl?: string; + clientId?: string; + clientSecret?: string; + scope: string | string[]; + headers?: any; + authorizationParams?: any; + params: any; + grantType?: string; + responseType?: string; } -export abstract class OAuth2Provider extends Provider { - abstract getSigninUrl(request: ServerRequest, state: string): string | Promise; - abstract getTokens(code: string, redirectUri: string): any | Promise; - abstract getUserProfile(tokens: any): any | Promise; +const defaultConfig: Partial = { + responseType: "code", + grantType: "authorization_code", +}; - async signin(request: ServerRequest): Promise { - const { method, host, query } = request; - const state = [`redirect=${query.get("redirect") ?? this.getUri(host, "/")}`].join(","); - const base64State = Buffer.from(state).toString("base64"); - const url = await this.getSigninUrl(request, base64State); +export class OAuth2Provider< + T extends OAuth2ProviderConfig = OAuth2ProviderConfig, +> extends OAuth2BaseProvider { + constructor(config: T) { + super({ + ...defaultConfig, + ...config, + }); + } - if (method === "POST") { - return { - body: { - redirect: url, - }, - }; - } - - return { - status: 302, - headers: { - Location: url, - }, + getAuthorizationUrl({ host }: ServerRequest, auth: Auth, state: string) { + const data = { + state, + response_type: this.config.responseType, + client_id: this.config.clientId, + scope: Array.isArray(this.config.scope) ? this.config.scope.join(" ") : this.config.scope, + redirect_uri: this.getCallbackUri(auth, host), + nonce: Math.round(Math.random() * 1000).toString(), // TODO: Generate random based on user values + ...(this.config.authorizationParams ?? {}), }; + + const url = `${this.config.authorizationUrl}?${new URLSearchParams(data)}`; + return url; } - getStateValue(query: URLSearchParams, name: string) { - if (query.get("state")) { - const state = Buffer.from(query.get("state")!, "base64").toString(); - return state - .split(",") - .find((state) => state.startsWith(`${name}=`)) - ?.replace(`${name}=`, ""); - } + async getTokens(code: string, redirectUri: string) { + const data = { + code, + grant_type: this.config.grantType, + client_id: this.config.clientId, + redirect_uri: redirectUri, + client_secret: this.config.clientSecret, + ...(this.config.params ?? {}), + }; + + const res = await fetch(`${this.config.accessTokenUrl}?${new URLSearchParams(data)}`); + return await res.json(); } - async callback({ query, host }: ServerRequest): Promise { - const code = query.get("code"); - const redirect = this.getStateValue(query, "redirect"); - - const tokens = await this.getTokens(code!, this.getCallbackUri(host)); - let user = await this.getUserProfile(tokens); - - if (this.config.profile) { - user = await this.config.profile(user, tokens); - } - - return [user, redirect ?? this.getUri(host, "/")]; + async getUserProfile(tokens: any) { + const res = await fetch(this.config.profileUrl!, { + headers: { Authorization: `${tokens.token_type} ${tokens.access_token}` }, + }); + return await res.json(); } } diff --git a/src/providers/reddit.ts b/src/providers/reddit.ts index 885ab3d..81f9520 100644 --- a/src/providers/reddit.ts +++ b/src/providers/reddit.ts @@ -1,11 +1,9 @@ -import type { ServerRequest } from "@sveltejs/kit/types/endpoint"; import { OAuth2Provider, OAuth2ProviderConfig } from "./oauth2"; -interface RedditOAuthProviderConfig extends OAuth2ProviderConfig { +interface RedditOAuth2ProviderConfig extends OAuth2ProviderConfig { + duration?: "temporary" | "permanent"; apiKey: string; apiSecret: string; - scope?: string; - duration?: "temporary" | "permanent"; } const redditProfileHandler = ({ @@ -54,46 +52,35 @@ const redditProfileHandler = ({ comment_karma, }); -const defaultConfig: Partial = { +const defaultConfig: Partial = { id: "reddit", scope: "identity", duration: "temporary", profile: redditProfileHandler, + authorizationUrl: "https://www.reddit.com/api/v1/authorize", + accessTokenUrl: "https://www.reddit.com/api/v1/access_token", + profileUrl: "https://oauth.reddit.com/api/v1/me", }; -export class RedditOAuthProvider extends OAuth2Provider { - constructor(config: RedditOAuthProviderConfig) { +export class RedditOAuth2Provider extends OAuth2Provider { + constructor(config: RedditOAuth2ProviderConfig) { super({ ...defaultConfig, ...config, + clientId: config.apiKey, + clientSecret: config.apiSecret, }); } static profileHandler = redditProfileHandler; - async getSigninUrl({ host }: ServerRequest, state: string) { - const endpoint = "https://www.reddit.com/api/v1/authorize"; - - const data = { - client_id: this.config.apiKey, - response_type: "code", - state, - redirect_uri: this.getCallbackUri(host), - duration: this.config.duration!, - scope: this.config.scope!, - }; - - const url = `${endpoint}?${new URLSearchParams(data)}`; - return url; - } - async getTokens(code: string, redirectUri: string) { - const endpoint = "https://www.reddit.com/api/v1/access_token"; + const endpoint = this.config.accessTokenUrl!; const data = { code, redirect_uri: redirectUri, - grant_type: "authorization_code", + grant_type: this.config.grantType!, }; const body = Object.entries(data) .map(([key, value]) => `${encodeURIComponent(key)}=${encodeURIComponent(value)}`) @@ -105,19 +92,11 @@ export class RedditOAuthProvider extends OAuth2Provider = { id: "twitter", }; -export class TwitterAuthProvider extends OAuth2Provider { +export class TwitterAuthProvider extends OAuth2BaseProvider { constructor(config: TwitterAuthProviderConfig) { super({ ...defaultConfig, @@ -37,7 +37,7 @@ export class TwitterAuthProvider extends OAuth2Provider