From 82ea2c9222657ea581927c829ad5d66a300a6de6 Mon Sep 17 00:00:00 2001 From: scalarbot Date: Mon, 22 Jun 2026 23:14:50 +0000 Subject: [PATCH] feat: scalar-typescript-sdk-go@0.2.5 --- .gitignore | 3 + README.md | 237 +- aliases.go | 16 + api.md | 907 +++ authentication.go | 84 + client.go | 111 + default_http_client.go | 30 + field.go | 45 + go.mod | 13 + go.sum | 10 + internal/apierror/apierror.go | 50 + internal/apiform/encoder.go | 493 ++ internal/apiform/form.go | 5 + internal/apiform/richparam.go | 20 + internal/apiform/tag.go | 88 + internal/apijson/decoder.go | 688 ++ internal/apijson/encoder.go | 379 + internal/apijson/enum.go | 140 + internal/apijson/field.go | 23 + internal/apijson/port.go | 120 + internal/apijson/registry.go | 51 + internal/apijson/subfield.go | 67 + internal/apijson/tag.go | 85 + internal/apijson/union.go | 208 + internal/apiquery/encoder.go | 394 + internal/apiquery/query.go | 55 + internal/apiquery/richparam.go | 19 + internal/apiquery/tag.go | 44 + internal/encoding/json/decode.go | 1324 ++++ internal/encoding/json/encode.go | 1395 ++++ internal/encoding/json/fold.go | 48 + internal/encoding/json/indent.go | 182 + internal/encoding/json/scanner.go | 610 ++ internal/encoding/json/sentinel/null.go | 46 + internal/encoding/json/shims/shims.go | 113 + internal/encoding/json/stream.go | 512 ++ internal/encoding/json/tables.go | 218 + internal/encoding/json/tags.go | 38 + internal/encoding/json/time.go | 61 + internal/paramutil/field.go | 30 + internal/paramutil/union.go | 48 + internal/requestconfig/requestconfig.go | 735 ++ internal/testutil/testutil.go | 27 + internal/version.go | 5 + loginportals.go | 237 + namespaces.go | 40 + openapi.augmented.json | 6921 ++++++++++++++++ option/middleware.go | 80 + option/requestoption.go | 263 + packages/pagination/pagination.go | 293 + packages/param/encoder.go | 109 + packages/param/null.go | 19 + packages/param/option.go | 121 + packages/param/param.go | 186 + packages/respjson/respjson.go | 88 + packages/ssestream/ssestream.go | 314 + registry.go | 423 + rules.go | 224 + scalar-sdk.manifest.json | 9662 +++++++++++++++++++++++ scalardocs.go | 125 + schemas.go | 142 + schemasaccessgroup.go | 97 + schemasversion.go | 115 + shared.go | 412 + shared/constant/constants.go | 32 + teams.go | 40 + tests/smoke-test.go | 911 +++ themes.go | 164 + 68 files changed, 30794 insertions(+), 1 deletion(-) create mode 100644 .gitignore create mode 100644 aliases.go create mode 100644 api.md create mode 100644 authentication.go create mode 100644 client.go create mode 100644 default_http_client.go create mode 100644 field.go create mode 100644 go.mod create mode 100644 go.sum create mode 100644 internal/apierror/apierror.go create mode 100644 internal/apiform/encoder.go create mode 100644 internal/apiform/form.go create mode 100644 internal/apiform/richparam.go create mode 100644 internal/apiform/tag.go create mode 100644 internal/apijson/decoder.go create mode 100644 internal/apijson/encoder.go create mode 100644 internal/apijson/enum.go create mode 100644 internal/apijson/field.go create mode 100644 internal/apijson/port.go create mode 100644 internal/apijson/registry.go create mode 100644 internal/apijson/subfield.go create mode 100644 internal/apijson/tag.go create mode 100644 internal/apijson/union.go create mode 100644 internal/apiquery/encoder.go create mode 100644 internal/apiquery/query.go create mode 100644 internal/apiquery/richparam.go create mode 100644 internal/apiquery/tag.go create mode 100644 internal/encoding/json/decode.go create mode 100644 internal/encoding/json/encode.go create mode 100644 internal/encoding/json/fold.go create mode 100644 internal/encoding/json/indent.go create mode 100644 internal/encoding/json/scanner.go create mode 100644 internal/encoding/json/sentinel/null.go create mode 100644 internal/encoding/json/shims/shims.go create mode 100644 internal/encoding/json/stream.go create mode 100644 internal/encoding/json/tables.go create mode 100644 internal/encoding/json/tags.go create mode 100644 internal/encoding/json/time.go create mode 100644 internal/paramutil/field.go create mode 100644 internal/paramutil/union.go create mode 100644 internal/requestconfig/requestconfig.go create mode 100644 internal/testutil/testutil.go create mode 100644 internal/version.go create mode 100644 loginportals.go create mode 100644 namespaces.go create mode 100644 openapi.augmented.json create mode 100644 option/middleware.go create mode 100644 option/requestoption.go create mode 100644 packages/pagination/pagination.go create mode 100644 packages/param/encoder.go create mode 100644 packages/param/null.go create mode 100644 packages/param/option.go create mode 100644 packages/param/param.go create mode 100644 packages/respjson/respjson.go create mode 100644 packages/ssestream/ssestream.go create mode 100644 registry.go create mode 100644 rules.go create mode 100644 scalar-sdk.manifest.json create mode 100644 scalardocs.go create mode 100644 schemas.go create mode 100644 schemasaccessgroup.go create mode 100644 schemasversion.go create mode 100644 shared.go create mode 100644 shared/constant/constants.go create mode 100644 teams.go create mode 100644 tests/smoke-test.go create mode 100644 themes.go diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..c9b4592 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +.idea/ +.vscode/ +*.swp diff --git a/README.md b/README.md index d15db55..335512d 100644 --- a/README.md +++ b/README.md @@ -1 +1,236 @@ -# go-sdk \ No newline at end of file +# Scalar API + +Generated Go SDK for Scalar API. +API for managing Scalar platform resources. + +## TypeScript SDK + +For TypeScript, we provide a SDK that makes using our API even easier. + +### Install + +```bash +npm add @scalar/sdk +``` + +### Get a Scalar API key + +Create an API key in your Scalar account: + +- Dashboard: https://dashboard.scalar.com/account +- Store it in `.env`, for example: + +```bash +SCALAR_API_KEY=your_personal_token +``` + +### Exchange your API key for an access token + +The personal token is not an access token. Exchange it first with `postv1AuthExchange`. + +If you use the personal token directly for authenticated API calls, the API returns `401 Invalid authentication token`. + +```ts +import { Scalar } from '@scalar/sdk' + +const scalar = new Scalar() + +const exchange = await scalar.auth.postv1AuthExchange({ + personalToken: process.env.SCALAR_API_KEY!, +}) + +const accessToken = exchange.accessToken +``` + +### Use the access token + +Construct a second client with bearer auth. Use this authenticated client for API calls. + +```ts +import { Scalar } from '@scalar/sdk' + +const scalar = new Scalar() + +const exchange = await scalar.auth.postv1AuthExchange({ + personalToken: process.env.SCALAR_API_KEY!, +}) + +const authedScalar = new Scalar({ + bearerAuth: exchange.accessToken, +}) +``` + +### Notes + +- The exchange request itself can be made from a client constructed with no arguments (`new Scalar()`). +- The exchanged access token is valid for 12 hours. +- Timestamps are Unix seconds. + +### Read more + +- [@scalar/sdk on npm](https://www.npmjs.com/package/@scalar/sdk) + +
+ +## Contents + +- [Installation](#installation) +- [Usage](#usage) +- [API Reference](./api.md) +- [Authentication](#authentication) +- [Errors](#errors) +- [Client Options](#client-options) +- [Request Options](#request-options) +- [Retries and Timeouts](#retries-and-timeouts) +- [Helpers](#helpers) +- [Logging](#logging) +- [Requirements](#requirements) + +
+ +## Installation + +```sh +go get scalar-api +``` + +
+ +## Usage + +```go +package main + +import ( + "context" + "fmt" + "os" + + sdk "scalar-api" + "scalar-api/option" +) + +func main() { + client := sdk.NewClient( + option.WithBearerAuth(os.Getenv("BEARER_AUTH")), + ) + + registry, err := client.Registry.ListAllAPIDocuments(context.Background()) + if err != nil { + panic(err) + } + fmt.Println(registry) +} +``` + +The examples in the following sections assume a `client` configured as shown above. + +See the [API reference](./api.md) for every available operation. + +
+ +## Authentication + +Pass credentials to the generated client constructor. Environment variables are read automatically when supported by the target runtime. + +| Option | Type | Default | Description | +| --- | --- | --- | --- | +| `option.WithBearerAuth` | `string \| provider` | - | Credential for the BearerAuth client option. Defaults to BEARER_AUTH. | + +Declared schemes: + +- `BearerAuth` bearer token + +
+ +## Errors + +Non-success responses return generated API errors. Error objects expose status, headers, response body, and request metadata where the target runtime supports it. + +```go +registry, err := client.Registry.ListAllAPIDocuments(context.Background()) +if err != nil { + var apiErr *sdk.Error + if errors.As(err, &apiErr) { + fmt.Println(apiErr.StatusCode, apiErr.RawJSON()) + } + panic(err) +} + +// imports: sdk "scalar-api", "errors", "fmt" +``` + +Documented error statuses: `400`, `401`, `403`, `404`, `422`, `500`. + +
+ +## Client Options + +Configure the generated client by setting any of these options when you create it. + +```go +client := sdk.NewClient( + option.WithBaseURL("https://api.example.com"), + option.WithMaxRetries(2), + option.WithTimeout(60*time.Second), +) + +// imports: sdk "scalar-api", "scalar-api/option", "time" +``` + +| Option | Type | Default | Description | +| --- | --- | --- | --- | +| `option.WithBearerAuth` | `func(string) option.RequestOption` | `os.Getenv("BEARER_AUTH")` | Credential for the BearerAuth client option. | +| `option.WithEnvironmentProduction` | `func() option.RequestOption` | - | Select the production API environment. | +| `option.WithEnvironmentLocal` | `func() option.RequestOption` | - | Select the local API environment. | +| `option.WithBaseURL` | `func(string) option.RequestOption` | `os.Getenv("SCALAR_BASE_URL")` | Override the default API base URL. | +| `option.WithTimeout` | `func(time.Duration) option.RequestOption` | - | Maximum time to wait for each request attempt. | +| `option.WithMaxRetries` | `func(int) option.RequestOption` | `2` | Number of retries for temporary failures. | +| `option.WithHTTPClient` | `func(option.HTTPClient) option.RequestOption` | - | Custom HTTP client or transport implementation. | + +
+ +## Request Options + +| Option | Type | Default | Description | +| --- | --- | --- | --- | +| `option.WithHeader` | `func(string, string) option.RequestOption` | - | Set a per-request header. | +| `option.WithQuery` | `func(string, string) option.RequestOption` | - | Set a per-request query parameter. | +| `option.WithRequestBody` | `func(string, any) option.RequestOption` | - | Override the serialized request body and content type. | +| `option.WithResponseInto` | `func(**http.Response) option.RequestOption` | - | Capture the raw HTTP response. | +| `option.WithResponseBodyInto` | `func(any) option.RequestOption` | - | Override the response deserialization target. | + +
+ +## Retries and Timeouts + +Generated clients support request timeouts and retry temporary failures such as network errors, 408, 409, 429, and 5xx responses. Retry delays honor `Retry-After` headers when present. Tune the retry and timeout client options shown above, or override them per request. + +
+ +## Helpers + +- Pass `option.WithResponseInto(&raw)` to capture the underlying `*http.Response` for a request. +- Use the generated `String`, `Int`, `Bool`, `Float`, `Time`, `Opt`, and `Ptr` helpers when setting optional params. + +
+ +## Logging + +- Wrap the HTTP client with `option.WithMiddleware(...)` to add request logging or tracing. + +
+ +## Requirements + +- Go 1.22 or newer + +Powered by Scalar. + + +## Contributions + +This SDK is generated programmatically. Manual edits to generated files will be +overwritten on the next build. + +### SDK created by [Scalar](https://www.scalar.com/?utm_source=scalar-typescript-sdk-go&utm_campaign=sdk) diff --git a/aliases.go b/aliases.go new file mode 100644 index 0000000..180e26e --- /dev/null +++ b/aliases.go @@ -0,0 +1,16 @@ +// Code generated by Scalar SDK Generator. DO NOT EDIT. + +package scalarapi + +import ( + "scalar-api/internal/apierror" + "scalar-api/packages/param" +) + +// aliased to make [param.APIUnion] private when embedding +type paramUnion = param.APIUnion + +// aliased to make [param.APIObject] private when embedding +type paramObj = param.APIObject + +type Error = apierror.Error diff --git a/api.md b/api.md new file mode 100644 index 0000000..b80e8d8 --- /dev/null +++ b/api.md @@ -0,0 +1,907 @@ +# Scalar Go API + +Complete reference of every operation, grouped by resource. See [the README](./README.md) for usage and configuration. + +## Contents + +- [`Registry`](#registry) + - [List all API Documents](#list-all-api-documents) + - [List API Documents in a namespace](#list-api-documents-in-a-namespace) + - [Create API Document](#create-api-document) + - [Update API Document metadata](#update-api-document-metadata) + - [Delete API Document](#delete-api-document) + - [Get API Document](#get-api-document) + - [Update API Document version](#update-api-document-version) + - [Delete API Document version](#delete-api-document-version) + - [Get API Document version metadata](#get-api-document-version-metadata) + - [Create API Document version](#create-api-document-version) + - [Add access group](#add-access-group) + - [Remove access group](#remove-access-group) +- [`Schemas`](#schemas) + - [List all shared components](#list-all-shared-components) + - [Create a shared component](#create-a-shared-component) + - [Update shared component metadata](#update-shared-component-metadata) + - [Delete a shared component](#delete-a-shared-component) + - [`Schemas Version`](#schemas-version) + - [Get a shared component document](#get-a-shared-component-document) + - [Delete a shared component version](#delete-a-shared-component-version) + - [Create a shared component version](#create-a-shared-component-version) + - [`Schemas AccessGroup`](#schemas-accessgroup) + - [Add shared component access group](#add-shared-component-access-group) + - [Remove shared component access group](#remove-shared-component-access-group) +- [`LoginPortals`](#loginportals) + - [Get a login portal](#get-a-login-portal) + - [Update portal metadata](#update-portal-metadata) + - [Delete a login portal](#delete-a-login-portal) + - [Create a portal](#create-a-portal) + - [List all portals](#list-all-portals) +- [`Rules`](#rules) + - [List all rules](#list-all-rules) + - [Create a rule](#create-a-rule) + - [Update rule metadata](#update-rule-metadata) + - [Delete a rule](#delete-a-rule) + - [Get a rule](#get-a-rule) + - [Add rule access group](#add-rule-access-group) + - [Remove rule access group](#remove-rule-access-group) +- [`Themes`](#themes) + - [List all themes](#list-all-themes) + - [Create a theme](#create-a-theme) + - [Update theme metadata](#update-theme-metadata) + - [Update theme document](#update-theme-document) + - [Delete a theme](#delete-a-theme) + - [Get a theme](#get-a-theme) +- [`Teams`](#teams) + - [List teams](#list-teams) +- [`ScalarDocs`](#scalardocs) + - [List all projects](#list-all-projects) + - [Create a project](#create-a-project) + - [Publish a project](#publish-a-project) +- [`Namespaces`](#namespaces) + - [List namespaces](#list-namespaces) +- [`Authentication`](#authentication) + - [Exchange token](#exchange-token) + - [Get current user](#get-current-user) + +## Setup + +```go +import ( + "context" + "fmt" + + sdk "scalar-api" +) + +client := sdk.NewClient() +``` + +## `Registry` + +### List all API Documents + +List all API documents across every namespace the caller can access. + +| Direction | Type | +| --- | --- | +| Response | [`[]APIDocument`](./shared.go) | + +```go +registry, err := client.Registry.ListAllAPIDocuments(context.Background()) +if err != nil { + panic(err) +} +fmt.Println(registry) +``` + +### List API Documents in a namespace + +List API documents in a namespace. + +| Direction | Type | +| --- | --- | +| Response | [`[]APIDocument`](./shared.go) | + +```go +registry, err := client.Registry.ListAPIDocuments(context.Background(), "namespace") +if err != nil { + panic(err) +} +fmt.Println(registry) +``` + +### Create API Document + +Create an API document. + +| Direction | Type | +| --- | --- | +| Request | [`RegistryNewAPIDocumentParams`](./registry.go) | +| Response | [`RegistryNewAPIDocumentResponse`](./registry.go) | + +```go +registry, err := client.Registry.NewAPIDocument(context.Background(), "namespace", sdk.RegistryNewAPIDocumentParams{ + Document: "", + Slug: "", + Title: "", + Version: "", +}) +if err != nil { + panic(err) +} +fmt.Println(registry) +``` + +### Update API Document metadata + +Update metadata for an API document. + +| Direction | Type | +| --- | --- | +| Request | [`RegistryUpdateAPIDocumentParams`](./registry.go) | + +```go +registry, err := client.Registry.UpdateAPIDocument(context.Background(), "namespace", "slug", sdk.RegistryUpdateAPIDocumentParams{}) +if err != nil { + panic(err) +} +fmt.Println(registry) +``` + +### Delete API Document + +Delete an API document and all versions. + +```go +registry, err := client.Registry.DeleteAPIDocument(context.Background(), "namespace", "slug") +if err != nil { + panic(err) +} +fmt.Println(registry) +``` + +### Get API Document + +Get a specific API document version. + +| Direction | Type | +| --- | --- | +| Response | `string` | + +```go +registry, err := client.Registry.GetAPIDocumentVersion(context.Background(), "namespace", "slug", "semver") +if err != nil { + panic(err) +} +fmt.Println(registry) +``` + +### Update API Document version + +Update the registry file content for an API document version. + +| Direction | Type | +| --- | --- | +| Request | [`RegistryUpdateAPIDocumentVersionParams`](./registry.go) | +| Response | [`RegistryUpdateAPIDocumentVersionResponse`](./registry.go) | + +```go +registry, err := client.Registry.UpdateAPIDocumentVersion(context.Background(), "namespace", "slug", "semver", sdk.RegistryUpdateAPIDocumentVersionParams{ + Document: "", +}) +if err != nil { + panic(err) +} +fmt.Println(registry) +``` + +### Delete API Document version + +Delete a specific API document version. + +```go +registry, err := client.Registry.DeleteAPIDocumentVersion(context.Background(), "namespace", "slug", "semver") +if err != nil { + panic(err) +} +fmt.Println(registry) +``` + +### Get API Document version metadata + +Get metadata (uid, content shas, version sha, tags) for a specific API document version. + +| Direction | Type | +| --- | --- | +| Response | [`ManagedDocVersion`](./registry.go) | + +```go +registry, err := client.Registry.ListAPIDocumentVersionMetadata(context.Background(), "namespace", "slug", "semver") +if err != nil { + panic(err) +} +fmt.Println(registry) +``` + +### Create API Document version + +Create a new API document version. + +| Direction | Type | +| --- | --- | +| Request | [`RegistryNewAPIDocumentVersionParams`](./registry.go) | +| Response | [`ManagedDocVersion`](./registry.go) | + +```go +registry, err := client.Registry.NewAPIDocumentVersion(context.Background(), "namespace", "slug", sdk.RegistryNewAPIDocumentVersionParams{ + Document: "", + Version: "", +}) +if err != nil { + panic(err) +} +fmt.Println(registry) +``` + +### Add access group + +Add an access group to an API document. + +| Direction | Type | +| --- | --- | +| Request | [`RegistryNewAPIDocumentAccessGroupParams`](./registry.go) | + +```go +registry, err := client.Registry.NewAPIDocumentAccessGroup(context.Background(), "namespace", "slug", sdk.RegistryNewAPIDocumentAccessGroupParams{ + AccessGroup: sdk.AccessGroup{ + AccessGroupSlug: "", +}, +}) +if err != nil { + panic(err) +} +fmt.Println(registry) +``` + +### Remove access group + +Remove an access group from an API document. + +| Direction | Type | +| --- | --- | +| Request | [`RegistryDeleteAPIDocumentAccessGroupParams`](./registry.go) | + +```go +registry, err := client.Registry.DeleteAPIDocumentAccessGroup(context.Background(), "namespace", "slug", sdk.RegistryDeleteAPIDocumentAccessGroupParams{ + AccessGroup: sdk.AccessGroup{ + AccessGroupSlug: "", +}, +}) +if err != nil { + panic(err) +} +fmt.Println(registry) +``` + +## `Schemas` + +### List all shared components + +List schemas in a namespace. + +| Direction | Type | +| --- | --- | +| Response | [`[]Schema`](./shared.go) | + +```go +schema, err := client.Schemas.List(context.Background(), "namespace") +if err != nil { + panic(err) +} +fmt.Println(schema) +``` + +### Create a shared component + +Create a schema in a namespace. + +| Direction | Type | +| --- | --- | +| Request | [`SchemaNewParams`](./schemas.go) | +| Response | [`UID`](./shared.go) | + +```go +schema, err := client.Schemas.New(context.Background(), "namespace", sdk.SchemaNewParams{ + Document: "", + Slug: "", + Title: "", + Version: "", +}) +if err != nil { + panic(err) +} +fmt.Println(schema) +``` + +### Update shared component metadata + +Update schema metadata. + +| Direction | Type | +| --- | --- | +| Request | [`SchemaUpdateParams`](./schemas.go) | + +```go +schema, err := client.Schemas.Update(context.Background(), "namespace", "slug", sdk.SchemaUpdateParams{}) +if err != nil { + panic(err) +} +fmt.Println(schema) +``` + +### Delete a shared component + +Delete a schema and all related versions. + +```go +schema, err := client.Schemas.Delete(context.Background(), "namespace", "slug") +if err != nil { + panic(err) +} +fmt.Println(schema) +``` + +### `Schemas Version` + +#### Get a shared component document + +Get a specific schema version document. + +| Direction | Type | +| --- | --- | +| Response | `string` | + +```go +version, err := client.Schemas.Version.GetSchema(context.Background(), "namespace", "slug", "semver") +if err != nil { + panic(err) +} +fmt.Println(version) +``` + +#### Delete a shared component version + +Delete a schema version. + +```go +version, err := client.Schemas.Version.DeleteSchema(context.Background(), "namespace", "slug", "semver") +if err != nil { + panic(err) +} +fmt.Println(version) +``` + +#### Create a shared component version + +Create a schema version. + +| Direction | Type | +| --- | --- | +| Request | [`SchemaVersionNewSchemaParams`](./schemasversion.go) | +| Response | [`UID`](./shared.go) | + +```go +version, err := client.Schemas.Version.NewSchema(context.Background(), "namespace", "slug", sdk.SchemaVersionNewSchemaParams{ + Document: "", + Version: "", +}) +if err != nil { + panic(err) +} +fmt.Println(version) +``` + +### `Schemas AccessGroup` + +#### Add shared component access group + +Add an access group to a schema. + +| Direction | Type | +| --- | --- | +| Request | [`SchemaAccessGroupNewSchemaParams`](./schemasaccessgroup.go) | + +```go +accessGroup, err := client.Schemas.AccessGroup.NewSchema(context.Background(), "namespace", "slug", sdk.SchemaAccessGroupNewSchemaParams{ + AccessGroup: sdk.AccessGroup{ + AccessGroupSlug: "", +}, +}) +if err != nil { + panic(err) +} +fmt.Println(accessGroup) +``` + +#### Remove shared component access group + +Remove an access group from a schema. + +| Direction | Type | +| --- | --- | +| Request | [`SchemaAccessGroupDeleteSchemaParams`](./schemasaccessgroup.go) | + +```go +accessGroup, err := client.Schemas.AccessGroup.DeleteSchema(context.Background(), "namespace", "slug", sdk.SchemaAccessGroupDeleteSchemaParams{ + AccessGroup: sdk.AccessGroup{ + AccessGroupSlug: "", +}, +}) +if err != nil { + panic(err) +} +fmt.Println(accessGroup) +``` + +## `LoginPortals` + +### Get a login portal + +Get a login portal by slug. + +| Direction | Type | +| --- | --- | +| Response | [`LoginPortalGetResponse`](./loginportals.go) | + +```go +loginPortal, err := client.LoginPortals.Get(context.Background(), "slug") +if err != nil { + panic(err) +} +fmt.Println(loginPortal) +``` + +### Update portal metadata + +Update metadata for a login portal. + +| Direction | Type | +| --- | --- | +| Request | [`LoginPortalUpdateParams`](./loginportals.go) | + +```go +loginPortal, err := client.LoginPortals.Update(context.Background(), "slug", sdk.LoginPortalUpdateParams{}) +if err != nil { + panic(err) +} +fmt.Println(loginPortal) +``` + +### Delete a login portal + +Delete a login portal. + +```go +loginPortal, err := client.LoginPortals.Delete(context.Background(), "slug") +if err != nil { + panic(err) +} +fmt.Println(loginPortal) +``` + +### Create a portal + +Create a login portal for the current team. + +| Direction | Type | +| --- | --- | +| Request | [`LoginPortalNewParams`](./loginportals.go) | +| Response | [`UID`](./shared.go) | + +```go +loginPortal, err := client.LoginPortals.New(context.Background(), sdk.LoginPortalNewParams{ + Email: sdk.LoginPortalEmail{ + Logo: "", + LogoSize: "100", + ButtonText: "Login", + Message: "Click to access private documentation hosted by scalar.com", + Title: "Private Docs", + MainColor: "#2a2f45", + MainBackground: "#f6f6f6", + CardColor: "2a2f45", + CardBackground: "#fff", + ButtonColor: "#fff", + ButtonBackground: "#0f0f0f", + }, + Page: sdk.LoginPortalPage{ + Title: "Scalar Private Docs", + Description: "Login to access your documentation", + Head: "", + Script: "", + Theme: "", + CompanyName: "", + Logo: "", + LogoURL: "", + Favicon: "", + TermsLink: "", + PrivacyLink: "", + FormTitle: "Scalar Private Docs", + FormDescription: "Login to access your documentation", + FormImage: "", + }, + Slug: "", + Title: "", +}) +if err != nil { + panic(err) +} +fmt.Println(loginPortal) +``` + +### List all portals + +List all login portals for the current team. + +| Direction | Type | +| --- | --- | +| Response | [`[]LoginPortal`](./shared.go) | + +```go +loginPortal, err := client.LoginPortals.List(context.Background()) +if err != nil { + panic(err) +} +fmt.Println(loginPortal) +``` + +## `Rules` + +### List all rules + +List all rulesets in a namespace. + +| Direction | Type | +| --- | --- | +| Response | [`[]Rule`](./shared.go) | + +```go +rule, err := client.Rules.ListRulesets(context.Background(), "namespace") +if err != nil { + panic(err) +} +fmt.Println(rule) +``` + +### Create a rule + +Create a rule in a namespace. + +| Direction | Type | +| --- | --- | +| Request | [`RuleNewRulesetParams`](./rules.go) | +| Response | [`UID`](./shared.go) | + +```go +rule, err := client.Rules.NewRuleset(context.Background(), "namespace", sdk.RuleNewRulesetParams{ + Document: "", + Slug: "", + Title: "", +}) +if err != nil { + panic(err) +} +fmt.Println(rule) +``` + +### Update rule metadata + +Update rule metadata by slug. + +| Direction | Type | +| --- | --- | +| Request | [`RuleUpdateRulesetParams`](./rules.go) | + +```go +rule, err := client.Rules.UpdateRuleset(context.Background(), "namespace", "slug", sdk.RuleUpdateRulesetParams{}) +if err != nil { + panic(err) +} +fmt.Println(rule) +``` + +### Delete a rule + +Delete a rule by slug. + +```go +rule, err := client.Rules.DeleteRuleset(context.Background(), "namespace", "slug") +if err != nil { + panic(err) +} +fmt.Println(rule) +``` + +### Get a rule + +Get a rule document by slug. + +| Direction | Type | +| --- | --- | +| Response | `string` | + +```go +rule, err := client.Rules.GetRulesetDocument(context.Background(), "namespace", "slug") +if err != nil { + panic(err) +} +fmt.Println(rule) +``` + +### Add rule access group + +Grant an access group to a rule. + +| Direction | Type | +| --- | --- | +| Request | [`RuleNewRulesetAccessGroupParams`](./rules.go) | + +```go +rule, err := client.Rules.NewRulesetAccessGroup(context.Background(), "namespace", "slug", sdk.RuleNewRulesetAccessGroupParams{ + AccessGroup: sdk.AccessGroup{ + AccessGroupSlug: "", +}, +}) +if err != nil { + panic(err) +} +fmt.Println(rule) +``` + +### Remove rule access group + +Remove an access group from a rule. + +| Direction | Type | +| --- | --- | +| Request | [`RuleDeleteRulesetAccessGroupParams`](./rules.go) | + +```go +rule, err := client.Rules.DeleteRulesetAccessGroup(context.Background(), "namespace", "slug", sdk.RuleDeleteRulesetAccessGroupParams{ + AccessGroup: sdk.AccessGroup{ + AccessGroupSlug: "", +}, +}) +if err != nil { + panic(err) +} +fmt.Println(rule) +``` + +## `Themes` + +### List all themes + +List all team themes. + +| Direction | Type | +| --- | --- | +| Response | [`[]Theme`](./shared.go) | + +```go +theme, err := client.Themes.List(context.Background()) +if err != nil { + panic(err) +} +fmt.Println(theme) +``` + +### Create a theme + +Create a team theme. + +| Direction | Type | +| --- | --- | +| Request | [`ThemeNewParams`](./themes.go) | +| Response | [`UID`](./shared.go) | + +```go +theme, err := client.Themes.New(context.Background(), sdk.ThemeNewParams{ + Document: "", + Name: "", + Slug: "", +}) +if err != nil { + panic(err) +} +fmt.Println(theme) +``` + +### Update theme metadata + +Update theme metadata. + +| Direction | Type | +| --- | --- | +| Request | [`ThemeUpdateParams`](./themes.go) | + +```go +theme, err := client.Themes.Update(context.Background(), "slug", sdk.ThemeUpdateParams{}) +if err != nil { + panic(err) +} +fmt.Println(theme) +``` + +### Update theme document + +Replace the theme document. + +| Direction | Type | +| --- | --- | +| Request | [`ThemeReplaceDocumentParams`](./themes.go) | + +```go +theme, err := client.Themes.ReplaceDocument(context.Background(), "slug", sdk.ThemeReplaceDocumentParams{ + Document: "", +}) +if err != nil { + panic(err) +} +fmt.Println(theme) +``` + +### Delete a theme + +Delete a theme by slug. + +```go +theme, err := client.Themes.Delete(context.Background(), "slug") +if err != nil { + panic(err) +} +fmt.Println(theme) +``` + +### Get a theme + +Get the theme document by slug. + +| Direction | Type | +| --- | --- | +| Response | `string` | + +```go +theme, err := client.Themes.Get(context.Background(), "slug") +if err != nil { + panic(err) +} +fmt.Println(theme) +``` + +## `Teams` + +### List teams + +List all available teams + +| Direction | Type | +| --- | --- | +| Response | [`[]Team`](./shared.go) | + +```go +team, err := client.Teams.List(context.Background()) +if err != nil { + panic(err) +} +fmt.Println(team) +``` + +## `ScalarDocs` + +### List all projects + +List all guide projects. + +| Direction | Type | +| --- | --- | +| Response | [`[]GithubProject`](./shared.go) | + +```go +scalarDoc, err := client.ScalarDocs.ListGuides(context.Background()) +if err != nil { + panic(err) +} +fmt.Println(scalarDoc) +``` + +### Create a project + +Create a guide project. + +| Direction | Type | +| --- | --- | +| Request | [`ScalarDocNewGuideParams`](./scalardocs.go) | +| Response | [`ScalarDocNewGuideResponse`](./scalardocs.go) | + +```go +scalarDoc, err := client.ScalarDocs.NewGuide(context.Background(), sdk.ScalarDocNewGuideParams{ + AllowedDomains: []string{""}, + AllowedUsers: []string{""}, + IsPrivate: false, + Name: "", +}) +if err != nil { + panic(err) +} +fmt.Println(scalarDoc) +``` + +### Publish a project + +Start a new publish process. + +| Direction | Type | +| --- | --- | +| Response | [`ScalarDocPublishGuideResponse`](./scalardocs.go) | + +```go +scalarDoc, err := client.ScalarDocs.PublishGuide(context.Background(), "slug") +if err != nil { + panic(err) +} +fmt.Println(scalarDoc) +``` + +## `Namespaces` + +### List namespaces + +Get all namespaces for the current team + +| Direction | Type | +| --- | --- | +| Response | `[]string` | + +```go +namespace, err := client.Namespaces.List(context.Background()) +if err != nil { + panic(err) +} +fmt.Println(namespace) +``` + +## `Authentication` + +### Exchange token + +Exchange an API key for an access token. + +| Direction | Type | +| --- | --- | +| Request | [`AuthenticationExchangePersonalTokenParams`](./authentication.go) | +| Response | [`AuthenticationExchangePersonalTokenResponse`](./authentication.go) | + +```go +authentication, err := client.Authentication.ExchangePersonalToken(context.Background(), sdk.AuthenticationExchangePersonalTokenParams{ + PersonalToken: "", +}) +if err != nil { + panic(err) +} +fmt.Println(authentication) +``` + +### Get current user + +Get the authenticated user, including their available teams and theme. + +| Direction | Type | +| --- | --- | +| Response | [`User`](./shared.go) | + +```go +authentication, err := client.Authentication.ListCurrentUser(context.Background()) +if err != nil { + panic(err) +} +fmt.Println(authentication) +``` diff --git a/authentication.go b/authentication.go new file mode 100644 index 0000000..bde215e --- /dev/null +++ b/authentication.go @@ -0,0 +1,84 @@ +// Code generated by Scalar SDK Generator. DO NOT EDIT. + +package scalarapi + +import ( + "context" + "net/http" + "slices" + + "scalar-api/internal/apijson" + "scalar-api/internal/requestconfig" + "scalar-api/option" + "scalar-api/packages/param" + "scalar-api/packages/respjson" +) + +// AuthenticationService contains methods and other services that help with interacting +// with the API. You should not instantiate this service directly, and instead use +// the [NewAuthenticationService] method instead. +type AuthenticationService struct { + options []option.RequestOption +} + +// NewAuthenticationService generates a new service that applies the given options to each request. +// These options are applied after the parent client's options (if there is one), and +// before any request-specific options. +func NewAuthenticationService(opts ...option.RequestOption) AuthenticationService { + s := AuthenticationService{} + s.options = opts + return s +} + +// Exchange an API key for an access token. +func (r *AuthenticationService) ExchangePersonalToken(ctx context.Context, body AuthenticationExchangePersonalTokenParams, opts ...option.RequestOption) (res *AuthenticationExchangePersonalTokenResponse, err error) { + opts = slices.Concat(r.options, opts) + requestPath := "v1/auth/exchange" + res = new(AuthenticationExchangePersonalTokenResponse) + err = requestconfig.ExecuteNewRequest(ctx, http.MethodPost, requestPath, body, &res, opts...) + if err != nil { + return nil, err + } + return res, nil +} + +type AuthenticationExchangePersonalTokenParams struct { + PersonalToken string `json:"personalToken" api:"required"` + paramObj +} + +func (r AuthenticationExchangePersonalTokenParams) MarshalJSON() (data []byte, err error) { + type shadow AuthenticationExchangePersonalTokenParams + return param.MarshalObject(r, (*shadow)(&r)) +} +func (r *AuthenticationExchangePersonalTokenParams) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +// Get the authenticated user, including their available teams and theme. +func (r *AuthenticationService) ListCurrentUser(ctx context.Context, opts ...option.RequestOption) (res *User, err error) { + opts = slices.Concat(r.options, opts) + requestPath := "v1/auth/me" + res = new(User) + err = requestconfig.ExecuteNewRequest(ctx, http.MethodGet, requestPath, nil, &res, opts...) + if err != nil { + return nil, err + } + return res, nil +} + +type AuthenticationExchangePersonalTokenResponse struct { + AccessToken string `json:"accessToken" api:"required"` + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + AccessToken respjson.Field + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` +} + +// RawJSON returns the unmodified JSON received from the API. +func (r AuthenticationExchangePersonalTokenResponse) RawJSON() string { return r.JSON.raw } +func (r *AuthenticationExchangePersonalTokenResponse) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} diff --git a/client.go b/client.go new file mode 100644 index 0000000..1e9e005 --- /dev/null +++ b/client.go @@ -0,0 +1,111 @@ +// Code generated by Scalar SDK Generator. DO NOT EDIT. + +package scalarapi + +import ( + "context" + "net/http" + "os" + "slices" + "strings" + + "scalar-api/internal/requestconfig" + "scalar-api/option" +) + +// Client creates a struct with services and top level methods that help with +// interacting with the Scalar API API. You should not instantiate this client +// directly, and instead use the [NewClient] method instead. +type Client struct { + options []option.RequestOption + Registry RegistryService + Schemas SchemaService + LoginPortals LoginPortalService + Rules RuleService + Themes ThemeService + Teams TeamService + ScalarDocs ScalarDocService + Namespaces NamespaceService + Authentication AuthenticationService +} + +// DefaultClientOptions read from the environment. This should be used to initialize +// new clients. +func DefaultClientOptions() []option.RequestOption { + defaults := []option.RequestOption{option.WithHTTPClient(defaultHTTPClient()), option.WithEnvironmentProduction()} + if o, ok := os.LookupEnv("SCALAR_BASE_URL"); ok && o != "" { + defaults = append(defaults, option.WithBaseURL(o)) + } + if o, ok := os.LookupEnv("BEARER_AUTH"); ok && o != "" { + defaults = append(defaults, option.WithBearerAuth(o)) + } + if o, ok := os.LookupEnv("SCALAR_CUSTOM_HEADERS"); ok { + for _, line := range strings.Split(o, "\n") { + colon := strings.Index(line, ":") + if colon >= 0 { + defaults = append(defaults, option.WithHeader(strings.TrimSpace(line[:colon]), strings.TrimSpace(line[colon+1:]))) + } + } + } + return defaults +} + +// NewClient generates a new client with the default option read from the +// environment. The option passed in as arguments are applied after these default +// arguments, and all option will be passed down to the services and requests that +// this client makes. +func NewClient(opts ...option.RequestOption) Client { + opts = append(DefaultClientOptions(), opts...) + + client := Client{options: opts} + + client.Registry = NewRegistryService(opts...) + client.Schemas = NewSchemaService(opts...) + client.LoginPortals = NewLoginPortalService(opts...) + client.Rules = NewRuleService(opts...) + client.Themes = NewThemeService(opts...) + client.Teams = NewTeamService(opts...) + client.ScalarDocs = NewScalarDocService(opts...) + client.Namespaces = NewNamespaceService(opts...) + client.Authentication = NewAuthenticationService(opts...) + + return client +} + + +// Execute makes a request with the given context, method, URL, request params, +// response, and request options. +func (r *Client) Execute(ctx context.Context, method string, path string, params any, res any, opts ...option.RequestOption) error { + opts = slices.Concat(r.options, opts) + return requestconfig.ExecuteNewRequest(ctx, method, path, params, res, opts...) +} + +// Get makes a GET request with the given URL, params, and optionally +// deserializes to a response. See [Execute] documentation on the params and response. +func (r *Client) Get(ctx context.Context, path string, params any, res any, opts ...option.RequestOption) error { + return r.Execute(ctx, http.MethodGet, path, params, res, opts...) +} + +// Post makes a POST request with the given URL, params, and optionally +// deserializes to a response. See [Execute] documentation on the params and response. +func (r *Client) Post(ctx context.Context, path string, params any, res any, opts ...option.RequestOption) error { + return r.Execute(ctx, http.MethodPost, path, params, res, opts...) +} + +// Put makes a PUT request with the given URL, params, and optionally +// deserializes to a response. See [Execute] documentation on the params and response. +func (r *Client) Put(ctx context.Context, path string, params any, res any, opts ...option.RequestOption) error { + return r.Execute(ctx, http.MethodPut, path, params, res, opts...) +} + +// Patch makes a PATCH request with the given URL, params, and optionally +// deserializes to a response. See [Execute] documentation on the params and response. +func (r *Client) Patch(ctx context.Context, path string, params any, res any, opts ...option.RequestOption) error { + return r.Execute(ctx, http.MethodPatch, path, params, res, opts...) +} + +// Delete makes a DELETE request with the given URL, params, and optionally +// deserializes to a response. See [Execute] documentation on the params and response. +func (r *Client) Delete(ctx context.Context, path string, params any, res any, opts ...option.RequestOption) error { + return r.Execute(ctx, http.MethodDelete, path, params, res, opts...) +} diff --git a/default_http_client.go b/default_http_client.go new file mode 100644 index 0000000..fe8e62a --- /dev/null +++ b/default_http_client.go @@ -0,0 +1,30 @@ +// File generated from our OpenAPI spec by Scalar. See README.md for details. + +package scalarapi + +import ( + "net/http" + "time" +) + +// defaultResponseHeaderTimeout bounds the time between a fully written request +// and the server's response headers. It does not apply to the response body, +// so long-running streams are unaffected. Without this, a server that accepts +// the connection but never responds would hang the request indefinitely. +const defaultResponseHeaderTimeout = 10 * time.Minute + +// defaultHTTPClient returns an [*http.Client] used when the caller does not +// supply one via [option.WithHTTPClient]. When [http.DefaultTransport] is the +// stdlib [*http.Transport], it is cloned and a [http.Transport.ResponseHeaderTimeout] +// is set so stuck connections fail fast instead of compounding across retries. +// If [http.DefaultTransport] has been wrapped (for example by otelhttp for +// distributed tracing), the wrapping is preserved and the header timeout is +// skipped. +func defaultHTTPClient() *http.Client { + if t, ok := http.DefaultTransport.(*http.Transport); ok { + t = t.Clone() + t.ResponseHeaderTimeout = defaultResponseHeaderTimeout + return &http.Client{Transport: t} + } + return &http.Client{Transport: http.DefaultTransport} +} diff --git a/field.go b/field.go new file mode 100644 index 0000000..cdfdfa9 --- /dev/null +++ b/field.go @@ -0,0 +1,45 @@ +package scalarapi + +import ( + "scalar-api/packages/param" + "io" + "time" +) + +func String(s string) param.Opt[string] { return param.NewOpt(s) } +func Int(i int64) param.Opt[int64] { return param.NewOpt(i) } +func Bool(b bool) param.Opt[bool] { return param.NewOpt(b) } +func Float(f float64) param.Opt[float64] { return param.NewOpt(f) } +func Time(t time.Time) param.Opt[time.Time] { return param.NewOpt(t) } + +func Opt[T comparable](v T) param.Opt[T] { return param.NewOpt(v) } +func Ptr[T any](v T) *T { return &v } + +func IntPtr(v int64) *int64 { return &v } +func BoolPtr(v bool) *bool { return &v } +func FloatPtr(v float64) *float64 { return &v } +func StringPtr(v string) *string { return &v } +func TimePtr(v time.Time) *time.Time { return &v } + +func File(rdr io.Reader, filename string, contentType string) file { + return file{rdr, filename, contentType} +} + +type file struct { + io.Reader + name string + contentType string +} + +func (f file) Filename() string { + if f.name != "" { + return f.name + } else if named, ok := f.Reader.(interface{ Name() string }); ok { + return named.Name() + } + return "" +} + +func (f file) ContentType() string { + return f.contentType +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..df4393c --- /dev/null +++ b/go.mod @@ -0,0 +1,13 @@ +module scalar-api + +go 1.22 + +require ( + github.com/tidwall/gjson v1.18.0 + github.com/tidwall/sjson v1.2.5 +) + +require ( + github.com/tidwall/match v1.1.1 // indirect + github.com/tidwall/pretty v1.2.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..32ba293 --- /dev/null +++ b/go.sum @@ -0,0 +1,10 @@ +github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= +github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= +github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= +github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= +github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= +github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= diff --git a/internal/apierror/apierror.go b/internal/apierror/apierror.go new file mode 100644 index 0000000..8db729b --- /dev/null +++ b/internal/apierror/apierror.go @@ -0,0 +1,50 @@ +// File generated from our OpenAPI spec by Scalar. See README.md for details. + +package apierror + +import ( + "fmt" + "net/http" + "net/http/httputil" + + "scalar-api/internal/apijson" + "scalar-api/packages/respjson" +) + +// Error represents an error that originates from the API, i.e. when a request is +// made and the API returns a response with a HTTP status code. Other errors are +// not wrapped by this SDK. +type Error struct { + // JSON contains metadata for fields, check presence with [respjson.Field.Valid]. + JSON struct { + ExtraFields map[string]respjson.Field + raw string + } `json:"-"` + StatusCode int + Request *http.Request + Response *http.Response +} + +// RawJSON returns the unmodified JSON received from the API. +func (r Error) RawJSON() string { return r.JSON.raw } +func (r *Error) UnmarshalJSON(data []byte) error { + return apijson.UnmarshalRoot(data, r) +} + +func (r *Error) Error() string { + // Attempt to re-populate the response body + return fmt.Sprintf("%s %q: %d %s %s", r.Request.Method, r.Request.URL, r.Response.StatusCode, http.StatusText(r.Response.StatusCode), r.JSON.raw) +} + +func (r *Error) DumpRequest(body bool) []byte { + if r.Request.GetBody != nil { + r.Request.Body, _ = r.Request.GetBody() + } + out, _ := httputil.DumpRequestOut(r.Request, body) + return out +} + +func (r *Error) DumpResponse(body bool) []byte { + out, _ := httputil.DumpResponse(r.Response, body) + return out +} diff --git a/internal/apiform/encoder.go b/internal/apiform/encoder.go new file mode 100644 index 0000000..737ccb0 --- /dev/null +++ b/internal/apiform/encoder.go @@ -0,0 +1,493 @@ +package apiform + +import ( + "fmt" + "io" + "mime/multipart" + "net/textproto" + "path" + "reflect" + "sort" + "strconv" + "strings" + "sync" + "time" + + "scalar-api/packages/param" +) + +var encoders sync.Map // map[encoderEntry]encoderFunc + +func Marshal(value any, writer *multipart.Writer) error { + e := &encoder{ + dateFormat: time.RFC3339, + arrayFmt: "repeat", + } + return e.marshal(value, writer) +} + +func MarshalRoot(value any, writer *multipart.Writer) error { + e := &encoder{ + root: true, + dateFormat: time.RFC3339, + arrayFmt: "repeat", + } + return e.marshal(value, writer) +} + +func MarshalWithSettings(value any, writer *multipart.Writer, arrayFormat string) error { + e := &encoder{ + arrayFmt: arrayFormat, + dateFormat: time.RFC3339, + } + return e.marshal(value, writer) +} + +type encoder struct { + arrayFmt string + dateFormat string + root bool +} + +type encoderFunc func(key string, value reflect.Value, writer *multipart.Writer) error + +type encoderField struct { + tag parsedStructTag + fn encoderFunc + idx []int +} + +type encoderEntry struct { + typ reflect.Type + dateFormat string + arrayFmt string + root bool +} + +func (e *encoder) marshal(value any, writer *multipart.Writer) error { + val := reflect.ValueOf(value) + if !val.IsValid() { + return nil + } + typ := val.Type() + enc := e.typeEncoder(typ) + return enc("", val, writer) +} + +func (e *encoder) typeEncoder(t reflect.Type) encoderFunc { + entry := encoderEntry{ + typ: t, + dateFormat: e.dateFormat, + arrayFmt: e.arrayFmt, + root: e.root, + } + + if fi, ok := encoders.Load(entry); ok { + return fi.(encoderFunc) + } + + // To deal with recursive types, populate the map with an + // indirect func before we build it. This type waits on the + // real func (f) to be ready and then calls it. This indirect + // func is only used for recursive types. + var ( + wg sync.WaitGroup + f encoderFunc + ) + wg.Add(1) + fi, loaded := encoders.LoadOrStore(entry, encoderFunc(func(key string, v reflect.Value, writer *multipart.Writer) error { + wg.Wait() + return f(key, v, writer) + })) + if loaded { + return fi.(encoderFunc) + } + + // Compute the real encoder and replace the indirect func with it. + f = e.newTypeEncoder(t) + wg.Done() + encoders.Store(entry, f) + return f +} + +func (e *encoder) newTypeEncoder(t reflect.Type) encoderFunc { + if t.ConvertibleTo(reflect.TypeOf(time.Time{})) { + return e.newTimeTypeEncoder() + } + if t.Implements(reflect.TypeOf((*io.Reader)(nil)).Elem()) { + return e.newReaderTypeEncoder() + } + e.root = false + switch t.Kind() { + case reflect.Pointer: + inner := t.Elem() + + innerEncoder := e.typeEncoder(inner) + return func(key string, v reflect.Value, writer *multipart.Writer) error { + if !v.IsValid() || v.IsNil() { + return nil + } + return innerEncoder(key, v.Elem(), writer) + } + case reflect.Struct: + return e.newStructTypeEncoder(t) + case reflect.Slice, reflect.Array: + return e.newArrayTypeEncoder(t) + case reflect.Map: + return e.newMapEncoder(t) + case reflect.Interface: + return e.newInterfaceEncoder() + default: + return e.newPrimitiveTypeEncoder(t) + } +} + +func (e *encoder) newPrimitiveTypeEncoder(t reflect.Type) encoderFunc { + switch t.Kind() { + // Note that we could use `gjson` to encode these types but it would complicate our + // code more and this current code shouldn't cause any issues + case reflect.String: + return func(key string, v reflect.Value, writer *multipart.Writer) error { + return writer.WriteField(key, v.String()) + } + case reflect.Bool: + return func(key string, v reflect.Value, writer *multipart.Writer) error { + if v.Bool() { + return writer.WriteField(key, "true") + } + return writer.WriteField(key, "false") + } + case reflect.Int, reflect.Int16, reflect.Int32, reflect.Int64: + return func(key string, v reflect.Value, writer *multipart.Writer) error { + return writer.WriteField(key, strconv.FormatInt(v.Int(), 10)) + } + case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return func(key string, v reflect.Value, writer *multipart.Writer) error { + return writer.WriteField(key, strconv.FormatUint(v.Uint(), 10)) + } + case reflect.Float32: + return func(key string, v reflect.Value, writer *multipart.Writer) error { + return writer.WriteField(key, strconv.FormatFloat(v.Float(), 'f', -1, 32)) + } + case reflect.Float64: + return func(key string, v reflect.Value, writer *multipart.Writer) error { + return writer.WriteField(key, strconv.FormatFloat(v.Float(), 'f', -1, 64)) + } + default: + return func(key string, v reflect.Value, writer *multipart.Writer) error { + return fmt.Errorf("unknown type received at primitive encoder: %s", t.String()) + } + } +} + +func (e *encoder) newArrayTypeEncoder(t reflect.Type) encoderFunc { + itemEncoder := e.typeEncoder(t.Elem()) + keyFn := e.arrayKeyEncoder() + if e.arrayFmt == "comma" { + return func(key string, v reflect.Value, writer *multipart.Writer) error { + if v.Len() == 0 { + return nil + } + elements := make([]string, v.Len()) + for i := 0; i < v.Len(); i++ { + elements[i] = fmt.Sprint(v.Index(i).Interface()) + } + return writer.WriteField(key, strings.Join(elements, ",")) + } + } + return func(key string, v reflect.Value, writer *multipart.Writer) error { + if keyFn == nil { + return fmt.Errorf("apiform: unsupported array format") + } + for i := 0; i < v.Len(); i++ { + err := itemEncoder(keyFn(key, i), v.Index(i), writer) + if err != nil { + return err + } + } + return nil + } +} + +func (e *encoder) newStructTypeEncoder(t reflect.Type) encoderFunc { + if t.Implements(reflect.TypeOf((*param.Optional)(nil)).Elem()) { + return e.newRichFieldTypeEncoder(t) + } + + for i := 0; i < t.NumField(); i++ { + if t.Field(i).Type == paramUnionType && t.Field(i).Anonymous { + return e.newStructUnionTypeEncoder(t) + } + } + + encoderFields := []encoderField{} + extraEncoder := (*encoderField)(nil) + + // This helper allows us to recursively collect field encoders into a flat + // array. The parameter `index` keeps track of the access patterns necessary + // to get to some field. + var collectEncoderFields func(r reflect.Type, index []int) + collectEncoderFields = func(r reflect.Type, index []int) { + for i := 0; i < r.NumField(); i++ { + idx := append(index, i) + field := t.FieldByIndex(idx) + if !field.IsExported() { + continue + } + // If this is an embedded struct, traverse one level deeper to extract + // the field and get their encoders as well. + if field.Anonymous { + collectEncoderFields(field.Type, idx) + continue + } + // If json tag is not present, then we skip, which is intentionally + // different behavior from the stdlib. + ptag, ok := parseFormStructTag(field) + if !ok { + continue + } + // We only want to support unexported field if they're tagged with + // `extras` because that field shouldn't be part of the public API. We + // also want to only keep the top level extras + if ptag.extras && len(index) == 0 { + extraEncoder = &encoderField{ptag, e.typeEncoder(field.Type.Elem()), idx} + continue + } + if ptag.name == "-" || ptag.name == "" { + continue + } + + dateFormat, ok := parseFormatStructTag(field) + oldFormat := e.dateFormat + if ok { + switch dateFormat { + case "date-time": + e.dateFormat = time.RFC3339 + case "date": + e.dateFormat = "2006-01-02" + } + } + + var encoderFn encoderFunc + if ptag.omitzero { + typeEncoderFn := e.typeEncoder(field.Type) + encoderFn = func(key string, value reflect.Value, writer *multipart.Writer) error { + if value.IsZero() { + return nil + } + return typeEncoderFn(key, value, writer) + } + } else if ptag.defaultValue != nil { + typeEncoderFn := e.typeEncoder(field.Type) + encoderFn = func(key string, value reflect.Value, writer *multipart.Writer) error { + if value.IsZero() { + return typeEncoderFn(key, reflect.ValueOf(ptag.defaultValue), writer) + } + return typeEncoderFn(key, value, writer) + } + } else { + encoderFn = e.typeEncoder(field.Type) + } + encoderFields = append(encoderFields, encoderField{ptag, encoderFn, idx}) + e.dateFormat = oldFormat + } + } + collectEncoderFields(t, []int{}) + + // Ensure deterministic output by sorting by lexicographic order + sort.Slice(encoderFields, func(i, j int) bool { + return encoderFields[i].tag.name < encoderFields[j].tag.name + }) + + return func(key string, value reflect.Value, writer *multipart.Writer) error { + keyFn := e.objKeyEncoder(key) + for _, ef := range encoderFields { + field := value.FieldByIndex(ef.idx) + err := ef.fn(keyFn(ef.tag.name), field, writer) + if err != nil { + return err + } + } + + if extraEncoder != nil { + err := e.encodeMapEntries(key, value.FieldByIndex(extraEncoder.idx), writer) + if err != nil { + return err + } + } + + return nil + } +} + +var paramUnionType = reflect.TypeOf((*param.APIUnion)(nil)).Elem() + +func (e *encoder) newStructUnionTypeEncoder(t reflect.Type) encoderFunc { + var fieldEncoders []encoderFunc + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + if field.Type == paramUnionType && field.Anonymous { + fieldEncoders = append(fieldEncoders, nil) + continue + } + fieldEncoders = append(fieldEncoders, e.typeEncoder(field.Type)) + } + + return func(key string, value reflect.Value, writer *multipart.Writer) error { + for i := 0; i < t.NumField(); i++ { + if value.Field(i).Type() == paramUnionType { + continue + } + if !value.Field(i).IsZero() { + return fieldEncoders[i](key, value.Field(i), writer) + } + } + return fmt.Errorf("apiform: union %s has no field set", t.String()) + } +} + +func (e *encoder) newTimeTypeEncoder() encoderFunc { + format := e.dateFormat + return func(key string, value reflect.Value, writer *multipart.Writer) error { + return writer.WriteField(key, value.Convert(reflect.TypeOf(time.Time{})).Interface().(time.Time).Format(format)) + } +} + +func (e encoder) newInterfaceEncoder() encoderFunc { + return func(key string, value reflect.Value, writer *multipart.Writer) error { + value = value.Elem() + if !value.IsValid() { + return nil + } + return e.typeEncoder(value.Type())(key, value, writer) + } +} + +var quoteEscaper = strings.NewReplacer("\\", "\\\\", `"`, "\\\"") + +func escapeQuotes(s string) string { + return quoteEscaper.Replace(s) +} + +func (e *encoder) newReaderTypeEncoder() encoderFunc { + return func(key string, value reflect.Value, writer *multipart.Writer) error { + reader, ok := value.Convert(reflect.TypeOf((*io.Reader)(nil)).Elem()).Interface().(io.Reader) + if !ok { + return nil + } + filename := "anonymous_file" + contentType := "application/octet-stream" + if named, ok := reader.(interface{ Filename() string }); ok { + filename = named.Filename() + } else if named, ok := reader.(interface{ Name() string }); ok { + filename = path.Base(named.Name()) + } + if typed, ok := reader.(interface{ ContentType() string }); ok { + contentType = typed.ContentType() + } + + // Below is taken almost 1-for-1 from [multipart.CreateFormFile] + h := make(textproto.MIMEHeader) + h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="%s"; filename="%s"`, escapeQuotes(key), escapeQuotes(filename))) + h.Set("Content-Type", contentType) + filewriter, err := writer.CreatePart(h) + if err != nil { + return err + } + _, err = io.Copy(filewriter, reader) + return err + } +} + +func (e encoder) arrayKeyEncoder() func(string, int) string { + var keyFn func(string, int) string + switch e.arrayFmt { + case "comma", "repeat": + keyFn = func(k string, _ int) string { return k } + case "brackets": + keyFn = func(key string, _ int) string { return key + "[]" } + case "indices:dots": + keyFn = func(k string, i int) string { + if k == "" { + return strconv.Itoa(i) + } + return k + "." + strconv.Itoa(i) + } + case "indices:brackets": + keyFn = func(k string, i int) string { + if k == "" { + return strconv.Itoa(i) + } + return k + "[" + strconv.Itoa(i) + "]" + } + } + return keyFn +} + +func (e encoder) objKeyEncoder(parent string) func(string) string { + if parent == "" { + return func(child string) string { return child } + } + switch e.arrayFmt { + case "brackets": + return func(child string) string { return parent + "[" + child + "]" } + default: + return func(child string) string { return parent + "." + child } + } +} + +// Given a []byte of json (may either be an empty object or an object that already contains entries) +// encode all of the entries in the map to the json byte array. +func (e *encoder) encodeMapEntries(key string, v reflect.Value, writer *multipart.Writer) error { + type mapPair struct { + key string + value reflect.Value + } + + pairs := []mapPair{} + + iter := v.MapRange() + for iter.Next() { + if iter.Key().Type().Kind() == reflect.String { + pairs = append(pairs, mapPair{key: iter.Key().String(), value: iter.Value()}) + } else { + return fmt.Errorf("cannot encode a map with a non string key") + } + } + + // Ensure deterministic output + sort.Slice(pairs, func(i, j int) bool { + return pairs[i].key < pairs[j].key + }) + + elementEncoder := e.typeEncoder(v.Type().Elem()) + keyFn := e.objKeyEncoder(key) + for _, p := range pairs { + err := elementEncoder(keyFn(p.key), p.value, writer) + if err != nil { + return err + } + } + + return nil +} + +func (e *encoder) newMapEncoder(_ reflect.Type) encoderFunc { + return func(key string, value reflect.Value, writer *multipart.Writer) error { + return e.encodeMapEntries(key, value, writer) + } +} + +func WriteExtras(writer *multipart.Writer, extras map[string]any) (err error) { + for k, v := range extras { + str, ok := v.(string) + if !ok { + break + } + err = writer.WriteField(k, str) + if err != nil { + break + } + } + return err +} diff --git a/internal/apiform/form.go b/internal/apiform/form.go new file mode 100644 index 0000000..5445116 --- /dev/null +++ b/internal/apiform/form.go @@ -0,0 +1,5 @@ +package apiform + +type Marshaler interface { + MarshalMultipart() ([]byte, string, error) +} diff --git a/internal/apiform/richparam.go b/internal/apiform/richparam.go new file mode 100644 index 0000000..bbdac1a --- /dev/null +++ b/internal/apiform/richparam.go @@ -0,0 +1,20 @@ +package apiform + +import ( + "scalar-api/packages/param" + "mime/multipart" + "reflect" +) + +func (e *encoder) newRichFieldTypeEncoder(t reflect.Type) encoderFunc { + f, _ := t.FieldByName("Value") + enc := e.newPrimitiveTypeEncoder(f.Type) + return func(key string, value reflect.Value, writer *multipart.Writer) error { + if opt, ok := value.Interface().(param.Optional); ok && opt.Valid() { + return enc(key, value.FieldByIndex(f.Index), writer) + } else if ok && param.IsNull(opt) { + return writer.WriteField(key, "null") + } + return nil + } +} diff --git a/internal/apiform/tag.go b/internal/apiform/tag.go new file mode 100644 index 0000000..f0c9d14 --- /dev/null +++ b/internal/apiform/tag.go @@ -0,0 +1,88 @@ +package apiform + +import ( + "reflect" + "strings" +) + +const apiStructTag = "api" +const jsonStructTag = "json" +const formStructTag = "form" +const formatStructTag = "format" +const defaultStructTag = "default" + +type parsedStructTag struct { + name string + required bool + extras bool + metadata bool + omitzero bool + defaultValue any +} + +func parseFormStructTag(field reflect.StructField) (tag parsedStructTag, ok bool) { + raw, ok := field.Tag.Lookup(formStructTag) + if !ok { + raw, ok = field.Tag.Lookup(jsonStructTag) + } + if !ok { + return tag, ok + } + parts := strings.Split(raw, ",") + if len(parts) == 0 { + return tag, false + } + tag.name = parts[0] + for _, part := range parts[1:] { + switch part { + case "required": + tag.required = true + case "extras": + tag.extras = true + case "metadata": + tag.metadata = true + case "omitzero": + tag.omitzero = true + } + } + + parseApiStructTag(field, &tag) + parseDefaultStructTag(field, &tag) + return tag, ok +} + +func parseDefaultStructTag(field reflect.StructField, tag *parsedStructTag) { + if field.Type.Kind() != reflect.String { + // Only strings are currently supported + return + } + + raw, ok := field.Tag.Lookup(defaultStructTag) + if !ok { + return + } + tag.defaultValue = raw +} + +func parseApiStructTag(field reflect.StructField, tag *parsedStructTag) { + raw, ok := field.Tag.Lookup(apiStructTag) + if !ok { + return + } + parts := strings.Split(raw, ",") + for _, part := range parts { + switch part { + case "extrafields": + tag.extras = true + case "required": + tag.required = true + case "metadata": + tag.metadata = true + } + } +} + +func parseFormatStructTag(field reflect.StructField) (format string, ok bool) { + format, ok = field.Tag.Lookup(formatStructTag) + return format, ok +} diff --git a/internal/apijson/decoder.go b/internal/apijson/decoder.go new file mode 100644 index 0000000..d507acd --- /dev/null +++ b/internal/apijson/decoder.go @@ -0,0 +1,688 @@ +// The deserialization algorithm from apijson may be subject to improvements +// between minor versions, particularly with respect to calling [json.Unmarshal] +// into param unions. + +package apijson + +import ( + "encoding/json" + "fmt" + "scalar-api/packages/param" + "reflect" + "strconv" + "sync" + "time" + "unsafe" + + "github.com/tidwall/gjson" +) + +// decoders is a synchronized map with roughly the following type: +// map[reflect.Type]decoderFunc +var decoders sync.Map + +// Unmarshal is similar to [encoding/json.Unmarshal] and parses the JSON-encoded +// data and stores it in the given pointer. +func Unmarshal(raw []byte, to any) error { + d := &decoderBuilder{dateFormat: time.RFC3339} + return d.unmarshal(raw, to) +} + +// UnmarshalRoot is like Unmarshal, but doesn't try to call MarshalJSON on the +// root element. Useful if a struct's UnmarshalJSON is overrode to use the +// behavior of this encoder versus the standard library. +func UnmarshalRoot(raw []byte, to any) error { + d := &decoderBuilder{dateFormat: time.RFC3339, root: true} + return d.unmarshal(raw, to) +} + +// decoderBuilder contains the 'compile-time' state of the decoder. +type decoderBuilder struct { + // Whether or not this is the first element and called by [UnmarshalRoot], see + // the documentation there to see why this is necessary. + root bool + // The dateFormat (a format string for [time.Format]) which is chosen by the + // last struct tag that was seen. + dateFormat string +} + +// decoderState contains the 'run-time' state of the decoder. +type decoderState struct { + strict bool + exactness exactness + validator *validationEntry +} + +// Exactness refers to how close to the type the result was if deserialization +// was successful. This is useful in deserializing unions, where you want to try +// each entry, first with strict, then with looser validation, without actually +// having to do a lot of redundant work by marshalling twice (or maybe even more +// times). +type exactness int8 + +const ( + // Some values had to fudged a bit, for example by converting a string to an + // int, or an enum with extra values. + loose exactness = iota + // There are some extra arguments, but other wise it matches the union. + extras + // Exactly right. + exact +) + +type decoderFunc func(node gjson.Result, value reflect.Value, state *decoderState) error + +type decoderField struct { + tag parsedStructTag + fn decoderFunc + idx []int + goname string +} + +type decoderEntry struct { + typ reflect.Type + dateFormat string + root bool +} + +func (d *decoderBuilder) unmarshal(raw []byte, to any) error { + value := reflect.ValueOf(to).Elem() + result := gjson.ParseBytes(raw) + if !value.IsValid() { + return fmt.Errorf("apijson: cannot marshal into invalid value") + } + return d.typeDecoder(value.Type())(result, value, &decoderState{strict: false, exactness: exact}) +} + +// unmarshalWithExactness is used for internal testing purposes. +func (d *decoderBuilder) unmarshalWithExactness(raw []byte, to any) (exactness, error) { + value := reflect.ValueOf(to).Elem() + result := gjson.ParseBytes(raw) + if !value.IsValid() { + return 0, fmt.Errorf("apijson: cannot marshal into invalid value") + } + state := decoderState{strict: false, exactness: exact} + err := d.typeDecoder(value.Type())(result, value, &state) + return state.exactness, err +} + +func (d *decoderBuilder) typeDecoder(t reflect.Type) decoderFunc { + entry := decoderEntry{ + typ: t, + dateFormat: d.dateFormat, + root: d.root, + } + + if fi, ok := decoders.Load(entry); ok { + return fi.(decoderFunc) + } + + // To deal with recursive types, populate the map with an + // indirect func before we build it. This type waits on the + // real func (f) to be ready and then calls it. This indirect + // func is only used for recursive types. + var ( + wg sync.WaitGroup + f decoderFunc + ) + wg.Add(1) + fi, loaded := decoders.LoadOrStore(entry, decoderFunc(func(node gjson.Result, v reflect.Value, state *decoderState) error { + wg.Wait() + return f(node, v, state) + })) + if loaded { + return fi.(decoderFunc) + } + + // Compute the real decoder and replace the indirect func with it. + f = d.newTypeDecoder(t) + wg.Done() + decoders.Store(entry, f) + return f +} + +// validatedTypeDecoder wraps the type decoder with a validator. This is helpful +// for ensuring that enum fields are correct. +func (d *decoderBuilder) validatedTypeDecoder(t reflect.Type, entry *validationEntry) decoderFunc { + dec := d.typeDecoder(t) + if entry == nil { + return dec + } + + // Thread the current validation entry through the decoder, + // but clean up in time for the next field. + return func(node gjson.Result, v reflect.Value, state *decoderState) error { + state.validator = entry + err := dec(node, v, state) + state.validator = nil + return err + } +} + +func indirectUnmarshalerDecoder(n gjson.Result, v reflect.Value, state *decoderState) error { + return v.Addr().Interface().(json.Unmarshaler).UnmarshalJSON([]byte(n.Raw)) +} + +func unmarshalerDecoder(n gjson.Result, v reflect.Value, state *decoderState) error { + if v.Kind() == reflect.Pointer && v.CanSet() { + v.Set(reflect.New(v.Type().Elem())) + } + return v.Interface().(json.Unmarshaler).UnmarshalJSON([]byte(n.Raw)) +} + +func (d *decoderBuilder) newTypeDecoder(t reflect.Type) decoderFunc { + if t.ConvertibleTo(reflect.TypeOf(time.Time{})) { + return d.newTimeTypeDecoder(t) + } + + if t.Implements(reflect.TypeOf((*param.Optional)(nil)).Elem()) { + return d.newOptTypeDecoder(t) + } + + if !d.root && t.Implements(reflect.TypeOf((*json.Unmarshaler)(nil)).Elem()) { + return unmarshalerDecoder + } + if !d.root && reflect.PointerTo(t).Implements(reflect.TypeOf((*json.Unmarshaler)(nil)).Elem()) { + if _, ok := unionVariants[t]; !ok { + return indirectUnmarshalerDecoder + } + } + d.root = false + + if _, ok := unionRegistry[t]; ok { + if isStructUnion(t) { + return d.newStructUnionDecoder(t) + } + return d.newUnionDecoder(t) + } + + switch t.Kind() { + case reflect.Pointer: + inner := t.Elem() + innerDecoder := d.typeDecoder(inner) + + return func(n gjson.Result, v reflect.Value, state *decoderState) error { + if !v.IsValid() { + return fmt.Errorf("apijson: unexpected invalid reflection value %+#v", v) + } + + newValue := reflect.New(inner).Elem() + err := innerDecoder(n, newValue, state) + if err != nil { + return err + } + + v.Set(newValue.Addr()) + return nil + } + case reflect.Struct: + if isStructUnion(t) { + return d.newStructUnionDecoder(t) + } + return d.newStructTypeDecoder(t) + case reflect.Array: + fallthrough + case reflect.Slice: + return d.newArrayTypeDecoder(t) + case reflect.Map: + return d.newMapDecoder(t) + case reflect.Interface: + return func(node gjson.Result, value reflect.Value, state *decoderState) error { + if !value.IsValid() { + return fmt.Errorf("apijson: unexpected invalid value %+#v", value) + } + if node.Value() != nil && value.CanSet() { + value.Set(reflect.ValueOf(node.Value())) + } + return nil + } + default: + return d.newPrimitiveTypeDecoder(t) + } +} + +func (d *decoderBuilder) newMapDecoder(t reflect.Type) decoderFunc { + keyType := t.Key() + itemType := t.Elem() + itemDecoder := d.typeDecoder(itemType) + + return func(node gjson.Result, value reflect.Value, state *decoderState) (err error) { + mapValue := reflect.MakeMapWithSize(t, len(node.Map())) + + node.ForEach(func(key, value gjson.Result) bool { + // It's fine for us to just use `ValueOf` here because the key types will + // always be primitive types so we don't need to decode it using the standard pattern + keyValue := reflect.ValueOf(key.Value()) + if !keyValue.IsValid() { + if err == nil { + err = fmt.Errorf("apijson: received invalid key type %v", keyValue.String()) + } + return false + } + if keyValue.Type() != keyType { + if err == nil { + err = fmt.Errorf("apijson: expected key type %v but got %v", keyType, keyValue.Type()) + } + return false + } + + itemValue := reflect.New(itemType).Elem() + itemerr := itemDecoder(value, itemValue, state) + if itemerr != nil { + if err == nil { + err = itemerr + } + return false + } + + mapValue.SetMapIndex(keyValue, itemValue) + return true + }) + + if err != nil { + return err + } + value.Set(mapValue) + return nil + } +} + +func (d *decoderBuilder) newArrayTypeDecoder(t reflect.Type) decoderFunc { + itemDecoder := d.typeDecoder(t.Elem()) + + return func(node gjson.Result, value reflect.Value, state *decoderState) (err error) { + if !node.IsArray() { + return fmt.Errorf("apijson: could not deserialize to an array") + } + + arrayNode := node.Array() + + arrayValue := reflect.MakeSlice(reflect.SliceOf(t.Elem()), len(arrayNode), len(arrayNode)) + for i, itemNode := range arrayNode { + err = itemDecoder(itemNode, arrayValue.Index(i), state) + if err != nil { + return err + } + } + + value.Set(arrayValue) + return nil + } +} + +func (d *decoderBuilder) newStructTypeDecoder(t reflect.Type) decoderFunc { + // map of json field name to struct field decoders + decoderFields := map[string]decoderField{} + anonymousDecoders := []decoderField{} + extraDecoder := (*decoderField)(nil) + var inlineDecoders []decoderField + + validationEntries := validationRegistry[t] + + for i := 0; i < t.NumField(); i++ { + idx := []int{i} + field := t.FieldByIndex(idx) + if !field.IsExported() { + continue + } + + var validator *validationEntry + for _, entry := range validationEntries { + if entry.field.Offset == field.Offset { + validator = &entry + break + } + } + + // If this is an embedded struct, traverse one level deeper to extract + // the fields and get their encoders as well. + if field.Anonymous { + anonymousDecoders = append(anonymousDecoders, decoderField{ + fn: d.typeDecoder(field.Type), + idx: idx[:], + }) + continue + } + // If json tag is not present, then we skip, which is intentionally + // different behavior from the stdlib. + ptag, ok := parseJSONStructTag(field) + if !ok { + continue + } + // We only want to support unexported fields if they're tagged with + // `extras` because that field shouldn't be part of the public API. + if ptag.extras { + extraDecoder = &decoderField{ptag, d.typeDecoder(field.Type.Elem()), idx, field.Name} + continue + } + if ptag.inline { + df := decoderField{ptag, d.typeDecoder(field.Type), idx, field.Name} + inlineDecoders = append(inlineDecoders, df) + continue + } + if ptag.metadata { + continue + } + + oldFormat := d.dateFormat + dateFormat, ok := parseFormatStructTag(field) + if ok { + switch dateFormat { + case "date-time": + d.dateFormat = time.RFC3339 + case "date": + d.dateFormat = "2006-01-02" + } + } + + decoderFields[ptag.name] = decoderField{ + ptag, + d.validatedTypeDecoder(field.Type, validator), + idx, field.Name, + } + + d.dateFormat = oldFormat + } + + return func(node gjson.Result, value reflect.Value, state *decoderState) (err error) { + if field := value.FieldByName("JSON"); field.IsValid() { + if raw := field.FieldByName("raw"); raw.IsValid() { + setUnexportedField(raw, node.Raw) + } + } + + for _, decoder := range anonymousDecoders { + // ignore errors + _ = decoder.fn(node, value.FieldByIndex(decoder.idx), state) + } + + for _, inlineDecoder := range inlineDecoders { + var meta Field + dest := value.FieldByIndex(inlineDecoder.idx) + isValid := false + if dest.IsValid() && node.Type != gjson.Null { + inlineState := decoderState{exactness: state.exactness, strict: true} + err = inlineDecoder.fn(node, dest, &inlineState) + if err == nil { + isValid = true + } + } + + if node.Type == gjson.Null { + meta = Field{ + raw: node.Raw, + status: null, + } + } else if !isValid { + // If an inline decoder fails, unset the field and move on. + if dest.IsValid() { + dest.SetZero() + } + continue + } else if isValid { + meta = Field{ + raw: node.Raw, + status: valid, + } + } + setMetadataSubField(value, inlineDecoder.idx, inlineDecoder.goname, meta) + } + + typedExtraType := reflect.Type(nil) + typedExtraFields := reflect.Value{} + if extraDecoder != nil { + typedExtraType = value.FieldByIndex(extraDecoder.idx).Type() + typedExtraFields = reflect.MakeMap(typedExtraType) + } + untypedExtraFields := map[string]Field{} + + for fieldName, itemNode := range node.Map() { + df, explicit := decoderFields[fieldName] + var ( + dest reflect.Value + fn decoderFunc + meta Field + ) + if explicit { + fn = df.fn + dest = value.FieldByIndex(df.idx) + } + if !explicit && extraDecoder != nil { + dest = reflect.New(typedExtraType.Elem()).Elem() + fn = extraDecoder.fn + } + + isValid := false + if dest.IsValid() && itemNode.Type != gjson.Null { + err = fn(itemNode, dest, state) + if err == nil { + isValid = true + } + } + + // Handle null [param.Opt] + if itemNode.Type == gjson.Null && dest.IsValid() && dest.Type().Implements(reflect.TypeOf((*param.Optional)(nil)).Elem()) { + _ = dest.Addr().Interface().(json.Unmarshaler).UnmarshalJSON([]byte(itemNode.Raw)) + continue + } + + if itemNode.Type == gjson.Null { + meta = Field{ + raw: itemNode.Raw, + status: null, + } + } else if !isValid { + meta = Field{ + raw: itemNode.Raw, + status: invalid, + } + } else if isValid { + meta = Field{ + raw: itemNode.Raw, + status: valid, + } + } + + if explicit { + setMetadataSubField(value, df.idx, df.goname, meta) + } + if !explicit { + untypedExtraFields[fieldName] = meta + } + if !explicit && extraDecoder != nil { + typedExtraFields.SetMapIndex(reflect.ValueOf(fieldName), dest) + } + } + + if extraDecoder != nil && typedExtraFields.Len() > 0 { + value.FieldByIndex(extraDecoder.idx).Set(typedExtraFields) + } + + // Set exactness to 'extras' if there are untyped, extra fields. + if len(untypedExtraFields) > 0 && state.exactness > extras { + state.exactness = extras + } + + if len(untypedExtraFields) > 0 { + setMetadataExtraFields(value, []int{-1}, "ExtraFields", untypedExtraFields) + } + return nil + } +} + +func (d *decoderBuilder) newPrimitiveTypeDecoder(t reflect.Type) decoderFunc { + switch t.Kind() { + case reflect.String: + return func(n gjson.Result, v reflect.Value, state *decoderState) error { + v.SetString(n.String()) + if guardStrict(state, n.Type != gjson.String) { + return fmt.Errorf("apijson: failed to parse string strictly") + } + // Everything that is not an object can be loosely stringified. + if n.Type == gjson.JSON { + return fmt.Errorf("apijson: failed to parse string") + } + + state.validateString(v) + + if guardUnknown(state, v) { + return fmt.Errorf("apijson: failed string enum validation") + } + return nil + } + case reflect.Bool: + return func(n gjson.Result, v reflect.Value, state *decoderState) error { + v.SetBool(n.Bool()) + if guardStrict(state, n.Type != gjson.True && n.Type != gjson.False) { + return fmt.Errorf("apijson: failed to parse bool strictly") + } + // Numbers and strings that are either 'true' or 'false' can be loosely + // deserialized as bool. + if n.Type == gjson.String && (n.Raw != "true" && n.Raw != "false") || n.Type == gjson.JSON { + return fmt.Errorf("apijson: failed to parse bool") + } + + state.validateBool(v) + + if guardUnknown(state, v) { + return fmt.Errorf("apijson: failed bool enum validation") + } + return nil + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return func(n gjson.Result, v reflect.Value, state *decoderState) error { + v.SetInt(n.Int()) + if guardStrict(state, n.Type != gjson.Number || n.Num != float64(int(n.Num))) { + return fmt.Errorf("apijson: failed to parse int strictly") + } + // Numbers, booleans, and strings that maybe look like numbers can be + // loosely deserialized as numbers. + if n.Type == gjson.JSON || (n.Type == gjson.String && !canParseAsNumber(n.Str)) { + return fmt.Errorf("apijson: failed to parse int") + } + + state.validateInt(v) + + if guardUnknown(state, v) { + return fmt.Errorf("apijson: failed int enum validation") + } + return nil + } + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return func(n gjson.Result, v reflect.Value, state *decoderState) error { + v.SetUint(n.Uint()) + if guardStrict(state, n.Type != gjson.Number || n.Num != float64(int(n.Num)) || n.Num < 0) { + return fmt.Errorf("apijson: failed to parse uint strictly") + } + // Numbers, booleans, and strings that maybe look like numbers can be + // loosely deserialized as uint. + if n.Type == gjson.JSON || (n.Type == gjson.String && !canParseAsNumber(n.Str)) { + return fmt.Errorf("apijson: failed to parse uint") + } + if guardUnknown(state, v) { + return fmt.Errorf("apijson: failed uint enum validation") + } + return nil + } + case reflect.Float32, reflect.Float64: + return func(n gjson.Result, v reflect.Value, state *decoderState) error { + v.SetFloat(n.Float()) + if guardStrict(state, n.Type != gjson.Number) { + return fmt.Errorf("apijson: failed to parse float strictly") + } + // Numbers, booleans, and strings that maybe look like numbers can be + // loosely deserialized as floats. + if n.Type == gjson.JSON || (n.Type == gjson.String && !canParseAsNumber(n.Str)) { + return fmt.Errorf("apijson: failed to parse float") + } + if guardUnknown(state, v) { + return fmt.Errorf("apijson: failed float enum validation") + } + return nil + } + default: + return func(node gjson.Result, v reflect.Value, state *decoderState) error { + return fmt.Errorf("unknown type received at primitive decoder: %s", t.String()) + } + } +} + +func (d *decoderBuilder) newOptTypeDecoder(t reflect.Type) decoderFunc { + for t.Kind() == reflect.Pointer { + t = t.Elem() + } + valueField, _ := t.FieldByName("Value") + return func(n gjson.Result, v reflect.Value, state *decoderState) error { + state.validateOptKind(n, valueField.Type) + return v.Addr().Interface().(json.Unmarshaler).UnmarshalJSON([]byte(n.Raw)) + } +} + +func (d *decoderBuilder) newTimeTypeDecoder(t reflect.Type) decoderFunc { + format := d.dateFormat + return func(n gjson.Result, v reflect.Value, state *decoderState) error { + parsed, err := time.Parse(format, n.Str) + if err == nil { + v.Set(reflect.ValueOf(parsed).Convert(t)) + return nil + } + + if guardStrict(state, true) { + return err + } + + layouts := []string{ + "2006-01-02", + "2006-01-02T15:04:05Z07:00", + "2006-01-02T15:04:05Z0700", + "2006-01-02T15:04:05", + "2006-01-02 15:04:05Z07:00", + "2006-01-02 15:04:05Z0700", + "2006-01-02 15:04:05", + } + + for _, layout := range layouts { + parsed, err := time.Parse(layout, n.Str) + if err == nil { + v.Set(reflect.ValueOf(parsed).Convert(t)) + return nil + } + } + + return fmt.Errorf("unable to leniently parse date-time string: %s", n.Str) + } +} + +func setUnexportedField(field reflect.Value, value any) { + reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem().Set(reflect.ValueOf(value)) +} + +func guardStrict(state *decoderState, cond bool) bool { + if !cond { + return false + } + + if state.strict { + return true + } + + state.exactness = loose + return false +} + +func canParseAsNumber(str string) bool { + _, err := strconv.ParseFloat(str, 64) + return err == nil +} + +var stringType = reflect.TypeOf(string("")) + +func guardUnknown(state *decoderState, v reflect.Value) bool { + if have, ok := v.Interface().(interface{ IsKnown() bool }); guardStrict(state, ok && !have.IsKnown()) { + return true + } + + constantString, ok := v.Interface().(interface{ Default() string }) + named := v.Type() != stringType + return guardStrict(state, ok && named && v.Equal(reflect.ValueOf(constantString.Default()))) +} diff --git a/internal/apijson/encoder.go b/internal/apijson/encoder.go new file mode 100644 index 0000000..66f26b7 --- /dev/null +++ b/internal/apijson/encoder.go @@ -0,0 +1,379 @@ +package apijson + +import ( + "bytes" + "encoding/json" + "fmt" + "reflect" + "sort" + "strconv" + "strings" + "sync" + "time" + + "github.com/tidwall/sjson" + + shimjson "scalar-api/internal/encoding/json" +) + +var encoders sync.Map // map[encoderEntry]encoderFunc + +// If we want to set a literal key value into JSON using sjson, we need to make sure it doesn't have +// special characters that sjson interprets as a path. +var EscapeSJSONKey = strings.NewReplacer("\\", "\\\\", "|", "\\|", "#", "\\#", "@", "\\@", "*", "\\*", ".", "\\.", ":", "\\:", "?", "\\?").Replace + +func Marshal(value any) ([]byte, error) { + e := &encoder{dateFormat: time.RFC3339} + return e.marshal(value) +} + +func MarshalRoot(value any) ([]byte, error) { + e := &encoder{root: true, dateFormat: time.RFC3339} + return e.marshal(value) +} + +type encoder struct { + dateFormat string + root bool +} + +type encoderFunc func(value reflect.Value) ([]byte, error) + +type encoderField struct { + tag parsedStructTag + fn encoderFunc + idx []int +} + +type encoderEntry struct { + typ reflect.Type + dateFormat string + root bool +} + +func (e *encoder) marshal(value any) ([]byte, error) { + val := reflect.ValueOf(value) + if !val.IsValid() { + return nil, nil + } + typ := val.Type() + enc := e.typeEncoder(typ) + return enc(val) +} + +func (e *encoder) typeEncoder(t reflect.Type) encoderFunc { + entry := encoderEntry{ + typ: t, + dateFormat: e.dateFormat, + root: e.root, + } + + if fi, ok := encoders.Load(entry); ok { + return fi.(encoderFunc) + } + + // To deal with recursive types, populate the map with an + // indirect func before we build it. This type waits on the + // real func (f) to be ready and then calls it. This indirect + // func is only used for recursive types. + var ( + wg sync.WaitGroup + f encoderFunc + ) + wg.Add(1) + fi, loaded := encoders.LoadOrStore(entry, encoderFunc(func(v reflect.Value) ([]byte, error) { + wg.Wait() + return f(v) + })) + if loaded { + return fi.(encoderFunc) + } + + // Compute the real encoder and replace the indirect func with it. + f = e.newTypeEncoder(t) + wg.Done() + encoders.Store(entry, f) + return f +} + +func marshalerEncoder(v reflect.Value) ([]byte, error) { + return v.Interface().(json.Marshaler).MarshalJSON() +} + +func indirectMarshalerEncoder(v reflect.Value) ([]byte, error) { + return v.Addr().Interface().(json.Marshaler).MarshalJSON() +} + +func (e *encoder) newTypeEncoder(t reflect.Type) encoderFunc { + if t.ConvertibleTo(reflect.TypeOf(time.Time{})) { + return e.newTimeTypeEncoder() + } + if !e.root && t.Implements(reflect.TypeOf((*json.Marshaler)(nil)).Elem()) { + return marshalerEncoder + } + if !e.root && reflect.PointerTo(t).Implements(reflect.TypeOf((*json.Marshaler)(nil)).Elem()) { + return indirectMarshalerEncoder + } + e.root = false + switch t.Kind() { + case reflect.Pointer: + inner := t.Elem() + + innerEncoder := e.typeEncoder(inner) + return func(v reflect.Value) ([]byte, error) { + if !v.IsValid() || v.IsNil() { + return nil, nil + } + return innerEncoder(v.Elem()) + } + case reflect.Struct: + return e.newStructTypeEncoder(t) + case reflect.Array: + fallthrough + case reflect.Slice: + return e.newArrayTypeEncoder(t) + case reflect.Map: + return e.newMapEncoder(t) + case reflect.Interface: + return e.newInterfaceEncoder() + default: + return e.newPrimitiveTypeEncoder(t) + } +} + +func (e *encoder) newPrimitiveTypeEncoder(t reflect.Type) encoderFunc { + switch t.Kind() { + // Note that we could use `gjson` to encode these types but it would complicate our + // code more and this current code shouldn't cause any issues + case reflect.String: + return func(v reflect.Value) ([]byte, error) { + return json.Marshal(v.Interface()) + } + case reflect.Bool: + return func(v reflect.Value) ([]byte, error) { + if v.Bool() { + return []byte("true"), nil + } + return []byte("false"), nil + } + case reflect.Int, reflect.Int16, reflect.Int32, reflect.Int64: + return func(v reflect.Value) ([]byte, error) { + return []byte(strconv.FormatInt(v.Int(), 10)), nil + } + case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return func(v reflect.Value) ([]byte, error) { + return []byte(strconv.FormatUint(v.Uint(), 10)), nil + } + case reflect.Float32: + return func(v reflect.Value) ([]byte, error) { + return []byte(strconv.FormatFloat(v.Float(), 'f', -1, 32)), nil + } + case reflect.Float64: + return func(v reflect.Value) ([]byte, error) { + return []byte(strconv.FormatFloat(v.Float(), 'f', -1, 64)), nil + } + default: + return func(v reflect.Value) ([]byte, error) { + return nil, fmt.Errorf("unknown type received at primitive encoder: %s", t.String()) + } + } +} + +func (e *encoder) newArrayTypeEncoder(t reflect.Type) encoderFunc { + itemEncoder := e.typeEncoder(t.Elem()) + + return func(value reflect.Value) ([]byte, error) { + json := []byte("[]") + for i := 0; i < value.Len(); i++ { + var value, err = itemEncoder(value.Index(i)) + if err != nil { + return nil, err + } + if value == nil { + // Assume that empty items should be inserted as `null` so that the output array + // will be the same length as the input array + value = []byte("null") + } + + json, err = sjson.SetRawBytes(json, "-1", value) + if err != nil { + return nil, err + } + } + + return json, nil + } +} + +func (e *encoder) newStructTypeEncoder(t reflect.Type) encoderFunc { + encoderFields := []encoderField{} + extraEncoder := (*encoderField)(nil) + + // This helper allows us to recursively collect field encoders into a flat + // array. The parameter `index` keeps track of the access patterns necessary + // to get to some field. + var collectEncoderFields func(r reflect.Type, index []int) + collectEncoderFields = func(r reflect.Type, index []int) { + for i := 0; i < r.NumField(); i++ { + idx := append(index, i) + field := t.FieldByIndex(idx) + if !field.IsExported() { + continue + } + // If this is an embedded struct, traverse one level deeper to extract + // the field and get their encoders as well. + if field.Anonymous { + collectEncoderFields(field.Type, idx) + continue + } + // If json tag is not present, then we skip, which is intentionally + // different behavior from the stdlib. + ptag, ok := parseJSONStructTag(field) + if !ok { + continue + } + // We only want to support unexported field if they're tagged with + // `extras` because that field shouldn't be part of the public API. We + // also want to only keep the top level extras + if ptag.extras && len(index) == 0 { + extraEncoder = &encoderField{ptag, e.typeEncoder(field.Type.Elem()), idx} + continue + } + if ptag.name == "-" { + continue + } + + dateFormat, ok := parseFormatStructTag(field) + oldFormat := e.dateFormat + if ok { + switch dateFormat { + case "date-time": + e.dateFormat = time.RFC3339 + case "date": + e.dateFormat = "2006-01-02" + } + } + encoderFields = append(encoderFields, encoderField{ptag, e.typeEncoder(field.Type), idx}) + e.dateFormat = oldFormat + } + } + collectEncoderFields(t, []int{}) + + // Ensure deterministic output by sorting by lexicographic order + sort.Slice(encoderFields, func(i, j int) bool { + return encoderFields[i].tag.name < encoderFields[j].tag.name + }) + + return func(value reflect.Value) (json []byte, err error) { + json = []byte("{}") + + for _, ef := range encoderFields { + field := value.FieldByIndex(ef.idx) + encoded, err := ef.fn(field) + if err != nil { + return nil, err + } + if ef.tag.defaultValue != nil && (!field.IsValid() || field.IsZero()) { + encoded, err = shimjson.Marshal(ef.tag.defaultValue) + if err != nil { + return nil, err + } + } + if encoded == nil { + continue + } + json, err = sjson.SetRawBytes(json, EscapeSJSONKey(ef.tag.name), encoded) + if err != nil { + return nil, err + } + } + + if extraEncoder != nil { + json, err = e.encodeMapEntries(json, value.FieldByIndex(extraEncoder.idx)) + if err != nil { + return nil, err + } + } + return json, err + } +} + +func (e *encoder) newTimeTypeEncoder() encoderFunc { + format := e.dateFormat + return func(value reflect.Value) (json []byte, err error) { + return []byte(`"` + value.Convert(reflect.TypeOf(time.Time{})).Interface().(time.Time).Format(format) + `"`), nil + } +} + +func (e encoder) newInterfaceEncoder() encoderFunc { + return func(value reflect.Value) ([]byte, error) { + value = value.Elem() + if !value.IsValid() { + return nil, nil + } + return e.typeEncoder(value.Type())(value) + } +} + +// Given a []byte of json (may either be an empty object or an object that already contains entries) +// encode all of the entries in the map to the json byte array. +func (e *encoder) encodeMapEntries(json []byte, v reflect.Value) ([]byte, error) { + type mapPair struct { + key []byte + value reflect.Value + } + + pairs := []mapPair{} + keyEncoder := e.typeEncoder(v.Type().Key()) + + iter := v.MapRange() + for iter.Next() { + var encodedKeyString string + if iter.Key().Type().Kind() == reflect.String { + encodedKeyString = iter.Key().String() + } else { + var err error + encodedKeyBytes, err := keyEncoder(iter.Key()) + if err != nil { + return nil, err + } + encodedKeyString = string(encodedKeyBytes) + } + encodedKey := []byte(encodedKeyString) + pairs = append(pairs, mapPair{key: encodedKey, value: iter.Value()}) + } + + // Ensure deterministic output + sort.Slice(pairs, func(i, j int) bool { + return bytes.Compare(pairs[i].key, pairs[j].key) < 0 + }) + + elementEncoder := e.typeEncoder(v.Type().Elem()) + for _, p := range pairs { + encodedValue, err := elementEncoder(p.value) + if err != nil { + return nil, err + } + if len(encodedValue) == 0 { + continue + } + json, err = sjson.SetRawBytes(json, EscapeSJSONKey(string(p.key)), encodedValue) + if err != nil { + return nil, err + } + } + + return json, nil +} + +func (e *encoder) newMapEncoder(_ reflect.Type) encoderFunc { + return func(value reflect.Value) ([]byte, error) { + json := []byte("{}") + var err error + json, err = e.encodeMapEntries(json, value) + if err != nil { + return nil, err + } + return json, nil + } +} diff --git a/internal/apijson/enum.go b/internal/apijson/enum.go new file mode 100644 index 0000000..a1626a5 --- /dev/null +++ b/internal/apijson/enum.go @@ -0,0 +1,140 @@ +package apijson + +import ( + "fmt" + "reflect" + "slices" + + "github.com/tidwall/gjson" +) + +/********************/ +/* Validating Enums */ +/********************/ + +type validationEntry struct { + field reflect.StructField + legalValues struct { + strings []string + // 1 represents true, 0 represents false, -1 represents either + bools int + ints []int64 + } +} + +var validationRegistry = map[reflect.Type][]validationEntry{} + +func RegisterFieldValidator[T any, V string | bool | int | float64](fieldName string, values ...V) { + var t T + parentType := reflect.TypeOf(t) + + if _, ok := validationRegistry[parentType]; !ok { + validationRegistry[parentType] = []validationEntry{} + } + + // The following checks run at initialization time, + // it is impossible for them to panic if any tests pass. + if parentType.Kind() != reflect.Struct { + panic(fmt.Sprintf("apijson: cannot initialize validator for non-struct %s", parentType.String())) + } + + var field reflect.StructField + found := false + for i := 0; i < parentType.NumField(); i++ { + ptag, ok := parseJSONStructTag(parentType.Field(i)) + if ok && ptag.name == fieldName { + field = parentType.Field(i) + found = true + break + } + } + + if !found { + panic(fmt.Sprintf("apijson: cannot find field %s in struct %s", fieldName, parentType.String())) + } + + newEntry := validationEntry{field: field} + newEntry.legalValues.bools = -1 // default to either + + switch values := any(values).(type) { + case []string: + newEntry.legalValues.strings = values + case []int: + newEntry.legalValues.ints = make([]int64, len(values)) + for i, value := range values { + newEntry.legalValues.ints[i] = int64(value) + } + case []bool: + for i, value := range values { + var next int + if value { + next = 1 + } + if i > 0 && newEntry.legalValues.bools != next { + newEntry.legalValues.bools = -1 // accept either + break + } + newEntry.legalValues.bools = next + } + } + + // Store the information necessary to create a validator, so that we can use it + // lazily create the validator function when did. + validationRegistry[parentType] = append(validationRegistry[parentType], newEntry) +} + +func (state *decoderState) validateString(v reflect.Value) { + if state.validator == nil { + return + } + if !slices.Contains(state.validator.legalValues.strings, v.String()) { + state.exactness = loose + } +} + +func (state *decoderState) validateInt(v reflect.Value) { + if state.validator == nil { + return + } + if !slices.Contains(state.validator.legalValues.ints, v.Int()) { + state.exactness = loose + } +} + +func (state *decoderState) validateBool(v reflect.Value) { + if state.validator == nil { + return + } + b := v.Bool() + if state.validator.legalValues.bools == 1 && !b { + state.exactness = loose + } else if state.validator.legalValues.bools == 0 && b { + state.exactness = loose + } +} + +func (state *decoderState) validateOptKind(node gjson.Result, t reflect.Type) { + switch node.Type { + case gjson.JSON: + state.exactness = loose + case gjson.Null: + return + case gjson.False, gjson.True: + if t.Kind() != reflect.Bool { + state.exactness = loose + } + case gjson.Number: + switch t.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, + reflect.Float32, reflect.Float64: + return + default: + state.exactness = loose + } + case gjson.String: + if t.Kind() != reflect.String { + state.exactness = loose + } + } +} diff --git a/internal/apijson/field.go b/internal/apijson/field.go new file mode 100644 index 0000000..854d6dd --- /dev/null +++ b/internal/apijson/field.go @@ -0,0 +1,23 @@ +package apijson + +type status uint8 + +const ( + missing status = iota + null + invalid + valid +) + +type Field struct { + raw string + status status +} + +// Returns true if the field is explicitly `null` _or_ if it is not present at all (ie, missing). +// To check if the field's key is present in the JSON with an explicit null value, +// you must check `f.IsNull() && !f.IsMissing()`. +func (j Field) IsNull() bool { return j.status <= null } +func (j Field) IsMissing() bool { return j.status == missing } +func (j Field) IsInvalid() bool { return j.status == invalid } +func (j Field) Raw() string { return j.raw } diff --git a/internal/apijson/port.go b/internal/apijson/port.go new file mode 100644 index 0000000..b40013c --- /dev/null +++ b/internal/apijson/port.go @@ -0,0 +1,120 @@ +package apijson + +import ( + "fmt" + "reflect" +) + +// Port copies over values from one struct to another struct. +func Port(from any, to any) error { + toVal := reflect.ValueOf(to) + fromVal := reflect.ValueOf(from) + + if toVal.Kind() != reflect.Ptr || toVal.IsNil() { + return fmt.Errorf("destination must be a non-nil pointer") + } + + for toVal.Kind() == reflect.Ptr { + toVal = toVal.Elem() + } + toType := toVal.Type() + + for fromVal.Kind() == reflect.Ptr { + fromVal = fromVal.Elem() + } + fromType := fromVal.Type() + + if toType.Kind() != reflect.Struct { + return fmt.Errorf("destination must be a non-nil pointer to a struct (%v %v)", toType, toType.Kind()) + } + + values := map[string]reflect.Value{} + fields := map[string]reflect.Value{} + + fromJSON := fromVal.FieldByName("JSON") + toJSON := toVal.FieldByName("JSON") + + // Iterate through the fields of v and load all the "normal" fields in the struct to the map of + // string to reflect.Value, as well as their raw .JSON.Foo counterpart indicated by j. + var getFields func(t reflect.Type, v reflect.Value) + getFields = func(t reflect.Type, v reflect.Value) { + j := v.FieldByName("JSON") + + // Recurse into anonymous fields first, since the fields on the object should win over the fields in the + // embedded object. + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + if field.Anonymous { + getFields(field.Type, v.Field(i)) + continue + } + } + + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + ptag, ok := parseJSONStructTag(field) + if !ok || ptag.name == "-" || ptag.name == "" { + continue + } + values[ptag.name] = v.Field(i) + if j.IsValid() { + fields[ptag.name] = j.FieldByName(field.Name) + } + } + } + getFields(fromType, fromVal) + + // Use the values from the previous step to populate the 'to' struct. + for i := 0; i < toType.NumField(); i++ { + field := toType.Field(i) + ptag, ok := parseJSONStructTag(field) + if !ok { + continue + } + if ptag.name == "-" { + continue + } + if value, ok := values[ptag.name]; ok { + delete(values, ptag.name) + if field.Type.Kind() == reflect.Interface { + toVal.Field(i).Set(value) + } else { + switch value.Kind() { + case reflect.String: + toVal.Field(i).SetString(value.String()) + case reflect.Bool: + toVal.Field(i).SetBool(value.Bool()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + toVal.Field(i).SetInt(value.Int()) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + toVal.Field(i).SetUint(value.Uint()) + case reflect.Float32, reflect.Float64: + toVal.Field(i).SetFloat(value.Float()) + default: + toVal.Field(i).Set(value) + } + } + } + + if fromJSONField, ok := fields[ptag.name]; ok { + if toJSONField := toJSON.FieldByName(field.Name); toJSONField.IsValid() { + toJSONField.Set(fromJSONField) + } + } + } + + // Finally, copy over the .JSON.raw and .JSON.ExtraFields + if toJSON.IsValid() { + if raw := toJSON.FieldByName("raw"); raw.IsValid() { + setUnexportedField(raw, fromJSON.Interface().(interface{ RawJSON() string }).RawJSON()) + } + + if toExtraFields := toJSON.FieldByName("ExtraFields"); toExtraFields.IsValid() { + if fromExtraFields := fromJSON.FieldByName("ExtraFields"); fromExtraFields.IsValid() { + setUnexportedField(toExtraFields, fromExtraFields.Interface()) + } + } + } + + return nil +} diff --git a/internal/apijson/registry.go b/internal/apijson/registry.go new file mode 100644 index 0000000..2a24982 --- /dev/null +++ b/internal/apijson/registry.go @@ -0,0 +1,51 @@ +package apijson + +import ( + "reflect" + + "github.com/tidwall/gjson" +) + +type UnionVariant struct { + TypeFilter gjson.Type + DiscriminatorValue any + Type reflect.Type +} + +var unionRegistry = map[reflect.Type]unionEntry{} +var unionVariants = map[reflect.Type]any{} + +type unionEntry struct { + discriminatorKey string + variants []UnionVariant +} + +func Discriminator[T any](value any) UnionVariant { + var zero T + return UnionVariant{ + TypeFilter: gjson.JSON, + DiscriminatorValue: value, + Type: reflect.TypeOf(zero), + } +} + +func RegisterUnion[T any](discriminator string, variants ...UnionVariant) { + typ := reflect.TypeOf((*T)(nil)).Elem() + unionRegistry[typ] = unionEntry{ + discriminatorKey: discriminator, + variants: variants, + } + for _, variant := range variants { + unionVariants[variant.Type] = typ + } +} + +// Useful to wrap a union type to force it to use [apijson.UnmarshalJSON] since you cannot define an +// UnmarshalJSON function on the interface itself. +type UnionUnmarshaler[T any] struct { + Value T +} + +func (c *UnionUnmarshaler[T]) UnmarshalJSON(buf []byte) error { + return UnmarshalRoot(buf, &c.Value) +} diff --git a/internal/apijson/subfield.go b/internal/apijson/subfield.go new file mode 100644 index 0000000..522ab3e --- /dev/null +++ b/internal/apijson/subfield.go @@ -0,0 +1,67 @@ +package apijson + +import ( + "scalar-api/packages/respjson" + "reflect" +) + +func getSubField(root reflect.Value, index []int, name string) reflect.Value { + strct := root.FieldByIndex(index[:len(index)-1]) + if !strct.IsValid() { + panic("couldn't find encapsulating struct for field " + name) + } + meta := strct.FieldByName("JSON") + if !meta.IsValid() { + return reflect.Value{} + } + field := meta.FieldByName(name) + if !field.IsValid() { + return reflect.Value{} + } + return field +} + +func setMetadataSubField(root reflect.Value, index []int, name string, meta Field) { + target := getSubField(root, index, name) + if !target.IsValid() { + return + } + + if target.Type() == reflect.TypeOf(meta) { + target.Set(reflect.ValueOf(meta)) + } else if respMeta := meta.toRespField(); target.Type() == reflect.TypeOf(respMeta) { + target.Set(reflect.ValueOf(respMeta)) + } +} + +func setMetadataExtraFields(root reflect.Value, index []int, name string, metaExtras map[string]Field) { + target := getSubField(root, index, name) + if !target.IsValid() { + return + } + + if target.Type() == reflect.TypeOf(metaExtras) { + target.Set(reflect.ValueOf(metaExtras)) + return + } + + newMap := make(map[string]respjson.Field, len(metaExtras)) + if target.Type() == reflect.TypeOf(newMap) { + for k, v := range metaExtras { + newMap[k] = v.toRespField() + } + target.Set(reflect.ValueOf(newMap)) + } +} + +func (f Field) toRespField() respjson.Field { + if f.IsMissing() { + return respjson.Field{} + } else if f.IsNull() { + return respjson.NewField("null") + } else if f.IsInvalid() { + return respjson.NewInvalidField(f.raw) + } else { + return respjson.NewField(f.raw) + } +} diff --git a/internal/apijson/tag.go b/internal/apijson/tag.go new file mode 100644 index 0000000..efcaf8c --- /dev/null +++ b/internal/apijson/tag.go @@ -0,0 +1,85 @@ +package apijson + +import ( + "reflect" + "strings" +) + +const apiStructTag = "api" +const jsonStructTag = "json" +const formatStructTag = "format" +const defaultStructTag = "default" + +type parsedStructTag struct { + name string + required bool + extras bool + metadata bool + inline bool + defaultValue any +} + +func parseJSONStructTag(field reflect.StructField) (tag parsedStructTag, ok bool) { + raw, ok := field.Tag.Lookup(jsonStructTag) + if !ok { + return tag, ok + } + parts := strings.Split(raw, ",") + if len(parts) == 0 { + return tag, false + } + tag.name = parts[0] + for _, part := range parts[1:] { + switch part { + case "required": + tag.required = true + case "extras": + tag.extras = true + case "metadata": + tag.metadata = true + case "inline": + tag.inline = true + } + } + + // the `api` struct tag is only used alongside `json` for custom behaviour + parseApiStructTag(field, &tag) + parseDefaultStructTag(field, &tag) + return tag, ok +} + +func parseDefaultStructTag(field reflect.StructField, tag *parsedStructTag) { + if field.Type.Kind() != reflect.String { + // Only strings are currently supported + return + } + + raw, ok := field.Tag.Lookup(defaultStructTag) + if !ok { + return + } + tag.defaultValue = raw +} + +func parseApiStructTag(field reflect.StructField, tag *parsedStructTag) { + raw, ok := field.Tag.Lookup(apiStructTag) + if !ok { + return + } + parts := strings.Split(raw, ",") + for _, part := range parts { + switch part { + case "extrafields": + tag.extras = true + case "required": + tag.required = true + case "metadata": + tag.metadata = true + } + } +} + +func parseFormatStructTag(field reflect.StructField) (format string, ok bool) { + format, ok = field.Tag.Lookup(formatStructTag) + return format, ok +} diff --git a/internal/apijson/union.go b/internal/apijson/union.go new file mode 100644 index 0000000..2181e3f --- /dev/null +++ b/internal/apijson/union.go @@ -0,0 +1,208 @@ +package apijson + +import ( + "errors" + "scalar-api/packages/param" + "reflect" + + "github.com/tidwall/gjson" +) + +var apiUnionType = reflect.TypeOf(param.APIUnion{}) + +func isStructUnion(t reflect.Type) bool { + if t.Kind() != reflect.Struct { + return false + } + for i := 0; i < t.NumField(); i++ { + if t.Field(i).Type == apiUnionType && t.Field(i).Anonymous { + return true + } + } + return false +} + +func RegisterDiscriminatedUnion[T any](key string, mappings map[string]reflect.Type) { + var t T + entry := unionEntry{ + discriminatorKey: key, + variants: []UnionVariant{}, + } + for k, typ := range mappings { + entry.variants = append(entry.variants, UnionVariant{ + DiscriminatorValue: k, + Type: typ, + }) + } + unionRegistry[reflect.TypeOf(t)] = entry +} + +func (d *decoderBuilder) newStructUnionDecoder(t reflect.Type) decoderFunc { + type variantDecoder struct { + decoder decoderFunc + field reflect.StructField + } + decoders := []variantDecoder{} + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + + if field.Anonymous && field.Type == apiUnionType { + continue + } + + decoder := d.typeDecoder(field.Type) + decoders = append(decoders, variantDecoder{ + decoder: decoder, + field: field, + }) + } + + type discriminatedDecoder struct { + variantDecoder + discriminator any + } + discriminatedDecoders := []discriminatedDecoder{} + unionEntry, discriminated := unionRegistry[t] + for _, variant := range unionEntry.variants { + // For each union variant, find a matching decoder and save it + for _, decoder := range decoders { + if decoder.field.Type.Elem() == variant.Type { + discriminatedDecoders = append(discriminatedDecoders, discriminatedDecoder{ + decoder, + variant.DiscriminatorValue, + }) + break + } + } + } + + return func(n gjson.Result, v reflect.Value, state *decoderState) error { + if discriminated && n.Type == gjson.JSON && len(unionEntry.discriminatorKey) != 0 { + discriminator := n.Get(EscapeSJSONKey(unionEntry.discriminatorKey)).Value() + for _, decoder := range discriminatedDecoders { + if discriminator == decoder.discriminator { + inner := v.FieldByIndex(decoder.field.Index) + return decoder.decoder(n, inner, state) + } + } + return errors.New("apijson: was not able to find discriminated union variant") + } + + // Set bestExactness to worse than loose + bestExactness := loose - 1 + bestVariant := -1 + for i, decoder := range decoders { + // Pointers are used to discern JSON object variants from value variants + if n.Type != gjson.JSON && decoder.field.Type.Kind() == reflect.Ptr { + continue + } + + sub := decoderState{strict: state.strict, exactness: exact} + inner := v.FieldByIndex(decoder.field.Index) + err := decoder.decoder(n, inner, &sub) + if err != nil { + continue + } + if sub.exactness == exact { + bestExactness = exact + bestVariant = i + break + } + if sub.exactness > bestExactness { + bestExactness = sub.exactness + bestVariant = i + } + } + + if bestExactness < loose { + return errors.New("apijson: was not able to coerce type as union") + } + + if guardStrict(state, bestExactness != exact) { + return errors.New("apijson: was not able to coerce type as union strictly") + } + + for i := 0; i < len(decoders); i++ { + if i == bestVariant { + continue + } + v.FieldByIndex(decoders[i].field.Index).SetZero() + } + + return nil + } +} + +// newUnionDecoder returns a decoderFunc that deserializes into a union using an +// algorithm roughly similar to Pydantic's [smart algorithm]. +// +// Conceptually this is equivalent to choosing the best schema based on how 'exact' +// the deserialization is for each of the schemas. +// +// If there is a tie in the level of exactness, then the tie is broken +// left-to-right. +// +// [smart algorithm]: https://docs.pydantic.dev/latest/concepts/unions/#smart-mode +func (d *decoderBuilder) newUnionDecoder(t reflect.Type) decoderFunc { + unionEntry, ok := unionRegistry[t] + if !ok { + panic("apijson: couldn't find union of type " + t.String() + " in union registry") + } + decoders := []decoderFunc{} + for _, variant := range unionEntry.variants { + decoder := d.typeDecoder(variant.Type) + decoders = append(decoders, decoder) + } + return func(n gjson.Result, v reflect.Value, state *decoderState) error { + // If there is a discriminator match, circumvent the exactness logic entirely + for idx, variant := range unionEntry.variants { + decoder := decoders[idx] + if variant.TypeFilter != n.Type { + continue + } + + if len(unionEntry.discriminatorKey) != 0 { + discriminatorValue := n.Get(EscapeSJSONKey(unionEntry.discriminatorKey)).Value() + if discriminatorValue == variant.DiscriminatorValue { + inner := reflect.New(variant.Type).Elem() + err := decoder(n, inner, state) + v.Set(inner) + return err + } + } + } + + // Set bestExactness to worse than loose + bestExactness := loose - 1 + for idx, variant := range unionEntry.variants { + decoder := decoders[idx] + if variant.TypeFilter != n.Type { + continue + } + sub := decoderState{strict: state.strict, exactness: exact} + inner := reflect.New(variant.Type).Elem() + err := decoder(n, inner, &sub) + if err != nil { + continue + } + if sub.exactness == exact { + v.Set(inner) + return nil + } + if sub.exactness > bestExactness { + v.Set(inner) + bestExactness = sub.exactness + } + } + + if bestExactness < loose { + return errors.New("apijson: was not able to coerce type as union") + } + + if guardStrict(state, bestExactness != exact) { + return errors.New("apijson: was not able to coerce type as union strictly") + } + + return nil + } +} diff --git a/internal/apiquery/encoder.go b/internal/apiquery/encoder.go new file mode 100644 index 0000000..b72102b --- /dev/null +++ b/internal/apiquery/encoder.go @@ -0,0 +1,394 @@ +package apiquery + +import ( + "encoding/json" + "fmt" + "reflect" + "strconv" + "strings" + "sync" + "time" + + "scalar-api/packages/param" +) + +var encoders sync.Map // map[reflect.Type]encoderFunc + +type encoder struct { + dateFormat string + root bool + settings QuerySettings +} + +type encoderFunc func(key string, value reflect.Value) ([]Pair, error) + +type encoderField struct { + tag parsedStructTag + fn encoderFunc + idx []int +} + +type encoderEntry struct { + typ reflect.Type + dateFormat string + root bool + settings QuerySettings +} + +type Pair struct { + key string + value string +} + +func (e *encoder) typeEncoder(t reflect.Type) encoderFunc { + entry := encoderEntry{ + typ: t, + dateFormat: e.dateFormat, + root: e.root, + settings: e.settings, + } + + if fi, ok := encoders.Load(entry); ok { + return fi.(encoderFunc) + } + + // To deal with recursive types, populate the map with an + // indirect func before we build it. This type waits on the + // real func (f) to be ready and then calls it. This indirect + // func is only used for recursive types. + var ( + wg sync.WaitGroup + f encoderFunc + ) + wg.Add(1) + fi, loaded := encoders.LoadOrStore(entry, encoderFunc(func(key string, v reflect.Value) ([]Pair, error) { + wg.Wait() + return f(key, v) + })) + if loaded { + return fi.(encoderFunc) + } + + // Compute the real encoder and replace the indirect func with it. + f = e.newTypeEncoder(t) + wg.Done() + encoders.Store(entry, f) + return f +} + +func marshalerEncoder(key string, value reflect.Value) ([]Pair, error) { + s, err := value.Interface().(json.Marshaler).MarshalJSON() + if err != nil { + return nil, fmt.Errorf("apiquery: json fallback marshal error %s", err) + } + return []Pair{{key, string(s)}}, nil +} + +func (e *encoder) newTypeEncoder(t reflect.Type) encoderFunc { + if t.ConvertibleTo(reflect.TypeOf(time.Time{})) { + return e.newTimeTypeEncoder(t) + } + + if t.Implements(reflect.TypeOf((*param.Optional)(nil)).Elem()) { + return e.newRichFieldTypeEncoder(t) + } + + if !e.root && t.Implements(reflect.TypeOf((*json.Marshaler)(nil)).Elem()) { + return marshalerEncoder + } + + e.root = false + switch t.Kind() { + case reflect.Pointer: + encoder := e.typeEncoder(t.Elem()) + return func(key string, value reflect.Value) (pairs []Pair, err error) { + if !value.IsValid() || value.IsNil() { + return pairs, err + } + return encoder(key, value.Elem()) + } + case reflect.Struct: + return e.newStructTypeEncoder(t) + case reflect.Array: + fallthrough + case reflect.Slice: + return e.newArrayTypeEncoder(t) + case reflect.Map: + return e.newMapEncoder(t) + case reflect.Interface: + return e.newInterfaceEncoder() + default: + return e.newPrimitiveTypeEncoder(t) + } +} + +func (e *encoder) newStructTypeEncoder(t reflect.Type) encoderFunc { + if t.Implements(reflect.TypeOf((*param.Optional)(nil)).Elem()) { + return e.newRichFieldTypeEncoder(t) + } + + for i := 0; i < t.NumField(); i++ { + if t.Field(i).Type == paramUnionType && t.Field(i).Anonymous { + return e.newStructUnionTypeEncoder(t) + } + } + + encoderFields := []encoderField{} + + // This helper allows us to recursively collect field encoders into a flat + // array. The parameter `index` keeps track of the access patterns necessary + // to get to some field. + var collectEncoderFields func(r reflect.Type, index []int) + collectEncoderFields = func(r reflect.Type, index []int) { + for i := 0; i < r.NumField(); i++ { + idx := append(index, i) + field := t.FieldByIndex(idx) + if !field.IsExported() { + continue + } + // If this is an embedded struct, traverse one level deeper to extract + // the field and get their encoders as well. + if field.Anonymous { + collectEncoderFields(field.Type, idx) + continue + } + // If query tag is not present, then we skip, which is intentionally + // different behavior from the stdlib. + ptag, ok := parseQueryStructTag(field) + if !ok { + continue + } + + if (ptag.name == "-" || ptag.name == "") && !ptag.inline { + continue + } + + dateFormat, ok := parseFormatStructTag(field) + oldFormat := e.dateFormat + if ok { + switch dateFormat { + case "date-time": + e.dateFormat = time.RFC3339 + case "date": + e.dateFormat = "2006-01-02" + } + } + var encoderFn encoderFunc + if ptag.omitzero { + typeEncoderFn := e.typeEncoder(field.Type) + encoderFn = func(key string, value reflect.Value) ([]Pair, error) { + if value.IsZero() { + return nil, nil + } + return typeEncoderFn(key, value) + } + } else { + encoderFn = e.typeEncoder(field.Type) + } + encoderFields = append(encoderFields, encoderField{ptag, encoderFn, idx}) + e.dateFormat = oldFormat + } + } + collectEncoderFields(t, []int{}) + + return func(key string, value reflect.Value) (pairs []Pair, err error) { + for _, ef := range encoderFields { + subkey := e.renderKeyPath(key, ef.tag.name) + if ef.tag.inline { + subkey = key + } + + field := value.FieldByIndex(ef.idx) + subpairs, suberr := ef.fn(subkey, field) + if suberr != nil { + err = suberr + } + pairs = append(pairs, subpairs...) + } + return pairs, err + } +} + +var paramUnionType = reflect.TypeOf((*param.APIUnion)(nil)).Elem() + +func (e *encoder) newStructUnionTypeEncoder(t reflect.Type) encoderFunc { + var fieldEncoders []encoderFunc + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + if field.Type == paramUnionType && field.Anonymous { + fieldEncoders = append(fieldEncoders, nil) + continue + } + fieldEncoders = append(fieldEncoders, e.typeEncoder(field.Type)) + } + + return func(key string, value reflect.Value) (pairs []Pair, err error) { + for i := 0; i < t.NumField(); i++ { + if value.Field(i).Type() == paramUnionType { + continue + } + if !value.Field(i).IsZero() { + return fieldEncoders[i](key, value.Field(i)) + } + } + return nil, fmt.Errorf("apiquery: union %s has no field set", t.String()) + } +} + +func (e *encoder) newMapEncoder(t reflect.Type) encoderFunc { + keyEncoder := e.typeEncoder(t.Key()) + elementEncoder := e.typeEncoder(t.Elem()) + return func(key string, value reflect.Value) (pairs []Pair, err error) { + iter := value.MapRange() + for iter.Next() { + encodedKey, err := keyEncoder("", iter.Key()) + if err != nil { + return nil, err + } + if len(encodedKey) != 1 { + return nil, fmt.Errorf("apiquery: unexpected number of parts for encoded map key, map may contain non-primitive") + } + subkey := encodedKey[0].value + keyPath := e.renderKeyPath(key, subkey) + subpairs, suberr := elementEncoder(keyPath, iter.Value()) + if suberr != nil { + err = suberr + } + pairs = append(pairs, subpairs...) + } + return pairs, err + } +} + +func (e *encoder) renderKeyPath(key string, subkey string) string { + if len(key) == 0 { + return subkey + } + if e.settings.NestedFormat == NestedQueryFormatDots { + return fmt.Sprintf("%s.%s", key, subkey) + } + return fmt.Sprintf("%s[%s]", key, subkey) +} + +func (e *encoder) newArrayTypeEncoder(t reflect.Type) encoderFunc { + switch e.settings.ArrayFormat { + case ArrayQueryFormatComma: + innerEncoder := e.typeEncoder(t.Elem()) + return func(key string, v reflect.Value) ([]Pair, error) { + elements := []string{} + for i := 0; i < v.Len(); i++ { + innerPairs, err := innerEncoder("", v.Index(i)) + if err != nil { + return nil, err + } + for _, pair := range innerPairs { + elements = append(elements, pair.value) + } + } + if len(elements) == 0 { + return []Pair{}, nil + } + return []Pair{{key, strings.Join(elements, ",")}}, nil + } + case ArrayQueryFormatRepeat: + innerEncoder := e.typeEncoder(t.Elem()) + return func(key string, value reflect.Value) (pairs []Pair, err error) { + for i := 0; i < value.Len(); i++ { + subpairs, suberr := innerEncoder(key, value.Index(i)) + if suberr != nil { + err = suberr + } + pairs = append(pairs, subpairs...) + } + return pairs, err + } + case ArrayQueryFormatIndices: + panic("The array indices format is not supported yet") + case ArrayQueryFormatBrackets: + innerEncoder := e.typeEncoder(t.Elem()) + return func(key string, value reflect.Value) (pairs []Pair, err error) { + pairs = []Pair{} + for i := 0; i < value.Len(); i++ { + subpairs, suberr := innerEncoder(key+"[]", value.Index(i)) + if suberr != nil { + err = suberr + } + pairs = append(pairs, subpairs...) + } + return pairs, err + } + default: + panic(fmt.Sprintf("Unknown ArrayFormat value: %d", e.settings.ArrayFormat)) + } +} + +func (e *encoder) newPrimitiveTypeEncoder(t reflect.Type) encoderFunc { + switch t.Kind() { + case reflect.Pointer: + inner := t.Elem() + + innerEncoder := e.newPrimitiveTypeEncoder(inner) + return func(key string, v reflect.Value) ([]Pair, error) { + if !v.IsValid() || v.IsNil() { + return nil, nil + } + return innerEncoder(key, v.Elem()) + } + case reflect.String: + return func(key string, v reflect.Value) ([]Pair, error) { + return []Pair{{key, v.String()}}, nil + } + case reflect.Bool: + return func(key string, v reflect.Value) ([]Pair, error) { + if v.Bool() { + return []Pair{{key, "true"}}, nil + } + return []Pair{{key, "false"}}, nil + } + case reflect.Int, reflect.Int16, reflect.Int32, reflect.Int64: + return func(key string, v reflect.Value) ([]Pair, error) { + return []Pair{{key, strconv.FormatInt(v.Int(), 10)}}, nil + } + case reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return func(key string, v reflect.Value) ([]Pair, error) { + return []Pair{{key, strconv.FormatUint(v.Uint(), 10)}}, nil + } + case reflect.Float32, reflect.Float64: + return func(key string, v reflect.Value) ([]Pair, error) { + return []Pair{{key, strconv.FormatFloat(v.Float(), 'f', -1, 64)}}, nil + } + case reflect.Complex64, reflect.Complex128: + bitSize := 64 + if t.Kind() == reflect.Complex128 { + bitSize = 128 + } + return func(key string, v reflect.Value) ([]Pair, error) { + return []Pair{{key, strconv.FormatComplex(v.Complex(), 'f', -1, bitSize)}}, nil + } + default: + return func(key string, v reflect.Value) ([]Pair, error) { + return nil, nil + } + } +} + +func (e *encoder) newTimeTypeEncoder(_ reflect.Type) encoderFunc { + format := e.dateFormat + return func(key string, value reflect.Value) ([]Pair, error) { + return []Pair{{ + key, + value.Convert(reflect.TypeOf(time.Time{})).Interface().(time.Time).Format(format), + }}, nil + } +} + +func (e encoder) newInterfaceEncoder() encoderFunc { + return func(key string, value reflect.Value) ([]Pair, error) { + value = value.Elem() + if !value.IsValid() { + return nil, nil + } + return e.typeEncoder(value.Type())(key, value) + } + +} diff --git a/internal/apiquery/query.go b/internal/apiquery/query.go new file mode 100644 index 0000000..0f379fa --- /dev/null +++ b/internal/apiquery/query.go @@ -0,0 +1,55 @@ +package apiquery + +import ( + "net/url" + "reflect" + "time" +) + +func MarshalWithSettings(value any, settings QuerySettings) (url.Values, error) { + e := encoder{time.RFC3339, true, settings} + kv := url.Values{} + val := reflect.ValueOf(value) + if !val.IsValid() { + return nil, nil + } + typ := val.Type() + + pairs, err := e.typeEncoder(typ)("", val) + if err != nil { + return nil, err + } + for _, pair := range pairs { + kv.Add(pair.key, pair.value) + } + return kv, nil +} + +func Marshal(value any) (url.Values, error) { + return MarshalWithSettings(value, QuerySettings{}) +} + +type Queryer interface { + URLQuery() (url.Values, error) +} + +type QuerySettings struct { + NestedFormat NestedQueryFormat + ArrayFormat ArrayQueryFormat +} + +type NestedQueryFormat int + +const ( + NestedQueryFormatBrackets NestedQueryFormat = iota + NestedQueryFormatDots +) + +type ArrayQueryFormat int + +const ( + ArrayQueryFormatComma ArrayQueryFormat = iota + ArrayQueryFormatRepeat + ArrayQueryFormatIndices + ArrayQueryFormatBrackets +) diff --git a/internal/apiquery/richparam.go b/internal/apiquery/richparam.go new file mode 100644 index 0000000..f2ab9ac --- /dev/null +++ b/internal/apiquery/richparam.go @@ -0,0 +1,19 @@ +package apiquery + +import ( + "scalar-api/packages/param" + "reflect" +) + +func (e *encoder) newRichFieldTypeEncoder(t reflect.Type) encoderFunc { + f, _ := t.FieldByName("Value") + enc := e.typeEncoder(f.Type) + return func(key string, value reflect.Value) ([]Pair, error) { + if opt, ok := value.Interface().(param.Optional); ok && opt.Valid() { + return enc(key, value.FieldByIndex(f.Index)) + } else if ok && param.IsNull(opt) { + return []Pair{{key, "null"}}, nil + } + return nil, nil + } +} diff --git a/internal/apiquery/tag.go b/internal/apiquery/tag.go new file mode 100644 index 0000000..9e413ad --- /dev/null +++ b/internal/apiquery/tag.go @@ -0,0 +1,44 @@ +package apiquery + +import ( + "reflect" + "strings" +) + +const queryStructTag = "query" +const formatStructTag = "format" + +type parsedStructTag struct { + name string + omitempty bool + omitzero bool + inline bool +} + +func parseQueryStructTag(field reflect.StructField) (tag parsedStructTag, ok bool) { + raw, ok := field.Tag.Lookup(queryStructTag) + if !ok { + return tag, ok + } + parts := strings.Split(raw, ",") + if len(parts) == 0 { + return tag, false + } + tag.name = parts[0] + for _, part := range parts[1:] { + switch part { + case "omitzero": + tag.omitzero = true + case "omitempty": + tag.omitempty = true + case "inline": + tag.inline = true + } + } + return tag, ok +} + +func parseFormatStructTag(field reflect.StructField) (format string, ok bool) { + format, ok = field.Tag.Lookup(formatStructTag) + return format, ok +} diff --git a/internal/encoding/json/decode.go b/internal/encoding/json/decode.go new file mode 100644 index 0000000..ed3be3f --- /dev/null +++ b/internal/encoding/json/decode.go @@ -0,0 +1,1324 @@ +// Vendored from Go 1.24.0-pre-release +// To find alterations, check package shims, and comments beginning in SHIM(). +// +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Represents JSON data structure using native Go types: booleans, floats, +// strings, arrays, and maps. + +package json + +import ( + "encoding" + "encoding/base64" + "fmt" + "scalar-api/internal/encoding/json/shims" + "reflect" + "strconv" + "strings" + "unicode" + "unicode/utf16" + "unicode/utf8" + _ "unsafe" // for linkname +) + +// Unmarshal parses the JSON-encoded data and stores the result +// in the value pointed to by v. If v is nil or not a pointer, +// Unmarshal returns an [InvalidUnmarshalError]. +// +// Unmarshal uses the inverse of the encodings that +// [Marshal] uses, allocating maps, slices, and pointers as necessary, +// with the following additional rules: +// +// To unmarshal JSON into a pointer, Unmarshal first handles the case of +// the JSON being the JSON literal null. In that case, Unmarshal sets +// the pointer to nil. Otherwise, Unmarshal unmarshals the JSON into +// the value pointed at by the pointer. If the pointer is nil, Unmarshal +// allocates a new value for it to point to. +// +// To unmarshal JSON into a value implementing [Unmarshaler], +// Unmarshal calls that value's [Unmarshaler.UnmarshalJSON] method, including +// when the input is a JSON null. +// Otherwise, if the value implements [encoding.TextUnmarshaler] +// and the input is a JSON quoted string, Unmarshal calls +// [encoding.TextUnmarshaler.UnmarshalText] with the unquoted form of the string. +// +// To unmarshal JSON into a struct, Unmarshal matches incoming object +// keys to the keys used by [Marshal] (either the struct field name or its tag), +// preferring an exact match but also accepting a case-insensitive match. By +// default, object keys which don't have a corresponding struct field are +// ignored (see [Decoder.DisallowUnknownFields] for an alternative). +// +// To unmarshal JSON into an interface value, +// Unmarshal stores one of these in the interface value: +// +// - bool, for JSON booleans +// - float64, for JSON numbers +// - string, for JSON strings +// - []any, for JSON arrays +// - map[string]any, for JSON objects +// - nil for JSON null +// +// To unmarshal a JSON array into a slice, Unmarshal resets the slice length +// to zero and then appends each element to the slice. +// As a special case, to unmarshal an empty JSON array into a slice, +// Unmarshal replaces the slice with a new empty slice. +// +// To unmarshal a JSON array into a Go array, Unmarshal decodes +// JSON array elements into corresponding Go array elements. +// If the Go array is smaller than the JSON array, +// the additional JSON array elements are discarded. +// If the JSON array is smaller than the Go array, +// the additional Go array elements are set to zero values. +// +// To unmarshal a JSON object into a map, Unmarshal first establishes a map to +// use. If the map is nil, Unmarshal allocates a new map. Otherwise Unmarshal +// reuses the existing map, keeping existing entries. Unmarshal then stores +// key-value pairs from the JSON object into the map. The map's key type must +// either be any string type, an integer, or implement [encoding.TextUnmarshaler]. +// +// If the JSON-encoded data contain a syntax error, Unmarshal returns a [SyntaxError]. +// +// If a JSON value is not appropriate for a given target type, +// or if a JSON number overflows the target type, Unmarshal +// skips that field and completes the unmarshaling as best it can. +// If no more serious errors are encountered, Unmarshal returns +// an [UnmarshalTypeError] describing the earliest such error. In any +// case, it's not guaranteed that all the remaining fields following +// the problematic one will be unmarshaled into the target object. +// +// The JSON null value unmarshals into an interface, map, pointer, or slice +// by setting that Go value to nil. Because null is often used in JSON to mean +// “not present,” unmarshaling a JSON null into any other Go type has no effect +// on the value and produces no error. +// +// When unmarshaling quoted strings, invalid UTF-8 or +// invalid UTF-16 surrogate pairs are not treated as an error. +// Instead, they are replaced by the Unicode replacement +// character U+FFFD. +func Unmarshal(data []byte, v any) error { + // Check for well-formedness. + // Avoids filling out half a data structure + // before discovering a JSON syntax error. + var d decodeState + err := checkValid(data, &d.scan) + if err != nil { + return err + } + + d.init(data) + return d.unmarshal(v) +} + +// Unmarshaler is the interface implemented by types +// that can unmarshal a JSON description of themselves. +// The input can be assumed to be a valid encoding of +// a JSON value. UnmarshalJSON must copy the JSON data +// if it wishes to retain the data after returning. +// +// By convention, to approximate the behavior of [Unmarshal] itself, +// Unmarshalers implement UnmarshalJSON([]byte("null")) as a no-op. +type Unmarshaler interface { + UnmarshalJSON([]byte) error +} + +// An UnmarshalTypeError describes a JSON value that was +// not appropriate for a value of a specific Go type. +type UnmarshalTypeError struct { + Value string // description of JSON value - "bool", "array", "number -5" + Type reflect.Type // type of Go value it could not be assigned to + Offset int64 // error occurred after reading Offset bytes + Struct string // name of the struct type containing the field + Field string // the full path from root node to the field, include embedded struct +} + +func (e *UnmarshalTypeError) Error() string { + if e.Struct != "" || e.Field != "" { + return "json: cannot unmarshal " + e.Value + " into Go struct field " + e.Struct + "." + e.Field + " of type " + e.Type.String() + } + return "json: cannot unmarshal " + e.Value + " into Go value of type " + e.Type.String() +} + +// An UnmarshalFieldError describes a JSON object key that +// led to an unexported (and therefore unwritable) struct field. +// +// Deprecated: No longer used; kept for compatibility. +type UnmarshalFieldError struct { + Key string + Type reflect.Type + Field reflect.StructField +} + +func (e *UnmarshalFieldError) Error() string { + return "json: cannot unmarshal object key " + strconv.Quote(e.Key) + " into unexported field " + e.Field.Name + " of type " + e.Type.String() +} + +// An InvalidUnmarshalError describes an invalid argument passed to [Unmarshal]. +// (The argument to [Unmarshal] must be a non-nil pointer.) +type InvalidUnmarshalError struct { + Type reflect.Type +} + +func (e *InvalidUnmarshalError) Error() string { + if e.Type == nil { + return "json: Unmarshal(nil)" + } + + if e.Type.Kind() != reflect.Pointer { + return "json: Unmarshal(non-pointer " + e.Type.String() + ")" + } + return "json: Unmarshal(nil " + e.Type.String() + ")" +} + +func (d *decodeState) unmarshal(v any) error { + rv := reflect.ValueOf(v) + if rv.Kind() != reflect.Pointer || rv.IsNil() { + return &InvalidUnmarshalError{reflect.TypeOf(v)} + } + + d.scan.reset() + d.scanWhile(scanSkipSpace) + // We decode rv not rv.Elem because the Unmarshaler interface + // test must be applied at the top level of the value. + err := d.value(rv) + if err != nil { + return d.addErrorContext(err) + } + return d.savedError +} + +// A Number represents a JSON number literal. +type Number string + +// String returns the literal text of the number. +func (n Number) String() string { return string(n) } + +// Float64 returns the number as a float64. +func (n Number) Float64() (float64, error) { + return strconv.ParseFloat(string(n), 64) +} + +// Int64 returns the number as an int64. +func (n Number) Int64() (int64, error) { + return strconv.ParseInt(string(n), 10, 64) +} + +// An errorContext provides context for type errors during decoding. +type errorContext struct { + Struct reflect.Type + FieldStack []string +} + +// decodeState represents the state while decoding a JSON value. +type decodeState struct { + data []byte + off int // next read offset in data + opcode int // last read result + scan scanner + errorContext *errorContext + savedError error + useNumber bool + disallowUnknownFields bool +} + +// readIndex returns the position of the last byte read. +func (d *decodeState) readIndex() int { + return d.off - 1 +} + +// phasePanicMsg is used as a panic message when we end up with something that +// shouldn't happen. It can indicate a bug in the JSON decoder, or that +// something is editing the data slice while the decoder executes. +const phasePanicMsg = "JSON decoder out of sync - data changing underfoot?" + +func (d *decodeState) init(data []byte) *decodeState { + d.data = data + d.off = 0 + d.savedError = nil + if d.errorContext != nil { + d.errorContext.Struct = nil + // Reuse the allocated space for the FieldStack slice. + d.errorContext.FieldStack = d.errorContext.FieldStack[:0] + } + return d +} + +// saveError saves the first err it is called with, +// for reporting at the end of the unmarshal. +func (d *decodeState) saveError(err error) { + if d.savedError == nil { + d.savedError = d.addErrorContext(err) + } +} + +// addErrorContext returns a new error enhanced with information from d.errorContext +func (d *decodeState) addErrorContext(err error) error { + if d.errorContext != nil && (d.errorContext.Struct != nil || len(d.errorContext.FieldStack) > 0) { + switch err := err.(type) { + case *UnmarshalTypeError: + err.Struct = d.errorContext.Struct.Name() + fieldStack := d.errorContext.FieldStack + if err.Field != "" { + fieldStack = append(fieldStack, err.Field) + } + err.Field = strings.Join(fieldStack, ".") + } + } + return err +} + +// skip scans to the end of what was started. +func (d *decodeState) skip() { + s, data, i := &d.scan, d.data, d.off + depth := len(s.parseState) + for { + op := s.step(s, data[i]) + i++ + if len(s.parseState) < depth { + d.off = i + d.opcode = op + return + } + } +} + +// scanNext processes the byte at d.data[d.off]. +func (d *decodeState) scanNext() { + if d.off < len(d.data) { + d.opcode = d.scan.step(&d.scan, d.data[d.off]) + d.off++ + } else { + d.opcode = d.scan.eof() + d.off = len(d.data) + 1 // mark processed EOF with len+1 + } +} + +// scanWhile processes bytes in d.data[d.off:] until it +// receives a scan code not equal to op. +func (d *decodeState) scanWhile(op int) { + s, data, i := &d.scan, d.data, d.off + for i < len(data) { + newOp := s.step(s, data[i]) + i++ + if newOp != op { + d.opcode = newOp + d.off = i + return + } + } + + d.off = len(data) + 1 // mark processed EOF with len+1 + d.opcode = d.scan.eof() +} + +// rescanLiteral is similar to scanWhile(scanContinue), but it specialises the +// common case where we're decoding a literal. The decoder scans the input +// twice, once for syntax errors and to check the length of the value, and the +// second to perform the decoding. +// +// Only in the second step do we use decodeState to tokenize literals, so we +// know there aren't any syntax errors. We can take advantage of that knowledge, +// and scan a literal's bytes much more quickly. +func (d *decodeState) rescanLiteral() { + data, i := d.data, d.off +Switch: + switch data[i-1] { + case '"': // string + for ; i < len(data); i++ { + switch data[i] { + case '\\': + i++ // escaped char + case '"': + i++ // tokenize the closing quote too + break Switch + } + } + case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '-': // number + for ; i < len(data); i++ { + switch data[i] { + case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', + '.', 'e', 'E', '+', '-': + default: + break Switch + } + } + case 't': // true + i += len("rue") + case 'f': // false + i += len("alse") + case 'n': // null + i += len("ull") + } + if i < len(data) { + d.opcode = stateEndValue(&d.scan, data[i]) + } else { + d.opcode = scanEnd + } + d.off = i + 1 +} + +// value consumes a JSON value from d.data[d.off-1:], decoding into v, and +// reads the following byte ahead. If v is invalid, the value is discarded. +// The first byte of the value has been read already. +func (d *decodeState) value(v reflect.Value) error { + switch d.opcode { + default: + panic(phasePanicMsg) + + case scanBeginArray: + if v.IsValid() { + if err := d.array(v); err != nil { + return err + } + } else { + d.skip() + } + d.scanNext() + + case scanBeginObject: + if v.IsValid() { + if err := d.object(v); err != nil { + return err + } + } else { + d.skip() + } + d.scanNext() + + case scanBeginLiteral: + // All bytes inside literal return scanContinue op code. + start := d.readIndex() + d.rescanLiteral() + + if v.IsValid() { + if err := d.literalStore(d.data[start:d.readIndex()], v, false); err != nil { + return err + } + } + } + return nil +} + +type unquotedValue struct{} + +// valueQuoted is like value but decodes a +// quoted string literal or literal null into an interface value. +// If it finds anything other than a quoted string literal or null, +// valueQuoted returns unquotedValue{}. +func (d *decodeState) valueQuoted() any { + switch d.opcode { + default: + panic(phasePanicMsg) + + case scanBeginArray, scanBeginObject: + d.skip() + d.scanNext() + + case scanBeginLiteral: + v := d.literalInterface() + switch v.(type) { + case nil, string: + return v + } + } + return unquotedValue{} +} + +// indirect walks down v allocating pointers as needed, +// until it gets to a non-pointer. +// If it encounters an Unmarshaler, indirect stops and returns that. +// If decodingNull is true, indirect stops at the first settable pointer so it +// can be set to nil. +func indirect(v reflect.Value, decodingNull bool) (Unmarshaler, encoding.TextUnmarshaler, reflect.Value) { + // Issue #24153 indicates that it is generally not a guaranteed property + // that you may round-trip a reflect.Value by calling Value.Addr().Elem() + // and expect the value to still be settable for values derived from + // unexported embedded struct fields. + // + // The logic below effectively does this when it first addresses the value + // (to satisfy possible pointer methods) and continues to dereference + // subsequent pointers as necessary. + // + // After the first round-trip, we set v back to the original value to + // preserve the original RW flags contained in reflect.Value. + v0 := v + haveAddr := false + + // If v is a named type and is addressable, + // start with its address, so that if the type has pointer methods, + // we find them. + if v.Kind() != reflect.Pointer && v.Type().Name() != "" && v.CanAddr() { + haveAddr = true + v = v.Addr() + } + for { + // Load value from interface, but only if the result will be + // usefully addressable. + if v.Kind() == reflect.Interface && !v.IsNil() { + e := v.Elem() + if e.Kind() == reflect.Pointer && !e.IsNil() && (!decodingNull || e.Elem().Kind() == reflect.Pointer) { + haveAddr = false + v = e + continue + } + } + + if v.Kind() != reflect.Pointer { + break + } + + if decodingNull && v.CanSet() { + break + } + + // Prevent infinite loop if v is an interface pointing to its own address: + // var v any + // v = &v + if v.Elem().Kind() == reflect.Interface && v.Elem().Elem().Equal(v) { + v = v.Elem() + break + } + if v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) + } + if v.Type().NumMethod() > 0 && v.CanInterface() { + if u, ok := v.Interface().(Unmarshaler); ok { + return u, nil, reflect.Value{} + } + if !decodingNull { + if u, ok := v.Interface().(encoding.TextUnmarshaler); ok { + return nil, u, reflect.Value{} + } + } + } + + if haveAddr { + v = v0 // restore original value after round-trip Value.Addr().Elem() + haveAddr = false + } else { + v = v.Elem() + } + } + return nil, nil, v +} + +// array consumes an array from d.data[d.off-1:], decoding into v. +// The first byte of the array ('[') has been read already. +func (d *decodeState) array(v reflect.Value) error { + // Check for unmarshaler. + u, ut, pv := indirect(v, false) + if u != nil { + start := d.readIndex() + d.skip() + return u.UnmarshalJSON(d.data[start:d.off]) + } + if ut != nil { + d.saveError(&UnmarshalTypeError{Value: "array", Type: v.Type(), Offset: int64(d.off)}) + d.skip() + return nil + } + v = pv + + // Check type of target. + switch v.Kind() { + case reflect.Interface: + if v.NumMethod() == 0 { + // Decoding into nil interface? Switch to non-reflect code. + ai := d.arrayInterface() + v.Set(reflect.ValueOf(ai)) + return nil + } + // Otherwise it's invalid. + fallthrough + default: + d.saveError(&UnmarshalTypeError{Value: "array", Type: v.Type(), Offset: int64(d.off)}) + d.skip() + return nil + case reflect.Array, reflect.Slice: + break + } + + i := 0 + for { + // Look ahead for ] - can only happen on first iteration. + d.scanWhile(scanSkipSpace) + if d.opcode == scanEndArray { + break + } + + // Expand slice length, growing the slice if necessary. + if v.Kind() == reflect.Slice { + if i >= v.Cap() { + v.Grow(1) + } + if i >= v.Len() { + v.SetLen(i + 1) + } + } + + if i < v.Len() { + // Decode into element. + if err := d.value(v.Index(i)); err != nil { + return err + } + } else { + // Ran out of fixed array: skip. + if err := d.value(reflect.Value{}); err != nil { + return err + } + } + i++ + + // Next token must be , or ]. + if d.opcode == scanSkipSpace { + d.scanWhile(scanSkipSpace) + } + if d.opcode == scanEndArray { + break + } + if d.opcode != scanArrayValue { + panic(phasePanicMsg) + } + } + + if i < v.Len() { + if v.Kind() == reflect.Array { + for ; i < v.Len(); i++ { + v.Index(i).SetZero() // zero remainder of array + } + } else { + v.SetLen(i) // truncate the slice + } + } + if i == 0 && v.Kind() == reflect.Slice { + v.Set(reflect.MakeSlice(v.Type(), 0, 0)) + } + return nil +} + +var nullLiteral = []byte("null") + +// SHIM(reflect): reflect.TypeFor[T]() reflect.T +var textUnmarshalerType = shims.TypeFor[encoding.TextUnmarshaler]() + +// object consumes an object from d.data[d.off-1:], decoding into v. +// The first byte ('{') of the object has been read already. +func (d *decodeState) object(v reflect.Value) error { + // Check for unmarshaler. + u, ut, pv := indirect(v, false) + if u != nil { + start := d.readIndex() + d.skip() + return u.UnmarshalJSON(d.data[start:d.off]) + } + if ut != nil { + d.saveError(&UnmarshalTypeError{Value: "object", Type: v.Type(), Offset: int64(d.off)}) + d.skip() + return nil + } + v = pv + t := v.Type() + + // Decoding into nil interface? Switch to non-reflect code. + if v.Kind() == reflect.Interface && v.NumMethod() == 0 { + oi := d.objectInterface() + v.Set(reflect.ValueOf(oi)) + return nil + } + + var fields structFields + + // Check type of target: + // struct or + // map[T1]T2 where T1 is string, an integer type, + // or an encoding.TextUnmarshaler + switch v.Kind() { + case reflect.Map: + // Map key must either have string kind, have an integer kind, + // or be an encoding.TextUnmarshaler. + switch t.Key().Kind() { + case reflect.String, + reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + default: + if !reflect.PointerTo(t.Key()).Implements(textUnmarshalerType) { + d.saveError(&UnmarshalTypeError{Value: "object", Type: t, Offset: int64(d.off)}) + d.skip() + return nil + } + } + if v.IsNil() { + v.Set(reflect.MakeMap(t)) + } + case reflect.Struct: + fields = cachedTypeFields(t) + // ok + default: + d.saveError(&UnmarshalTypeError{Value: "object", Type: t, Offset: int64(d.off)}) + d.skip() + return nil + } + + var mapElem reflect.Value + var origErrorContext errorContext + if d.errorContext != nil { + origErrorContext = *d.errorContext + } + + for { + // Read opening " of string key or closing }. + d.scanWhile(scanSkipSpace) + if d.opcode == scanEndObject { + // closing } - can only happen on first iteration. + break + } + if d.opcode != scanBeginLiteral { + panic(phasePanicMsg) + } + + // Read key. + start := d.readIndex() + d.rescanLiteral() + item := d.data[start:d.readIndex()] + key, ok := unquoteBytes(item) + if !ok { + panic(phasePanicMsg) + } + + // Figure out field corresponding to key. + var subv reflect.Value + destring := false // whether the value is wrapped in a string to be decoded first + + if v.Kind() == reflect.Map { + elemType := t.Elem() + if !mapElem.IsValid() { + mapElem = reflect.New(elemType).Elem() + } else { + mapElem.SetZero() + } + subv = mapElem + } else { + f := fields.byExactName[string(key)] + if f == nil { + f = fields.byFoldedName[string(foldName(key))] + } + if f != nil { + subv = v + destring = f.quoted + if d.errorContext == nil { + d.errorContext = new(errorContext) + } + for i, ind := range f.index { + if subv.Kind() == reflect.Pointer { + if subv.IsNil() { + // If a struct embeds a pointer to an unexported type, + // it is not possible to set a newly allocated value + // since the field is unexported. + // + // See https://golang.org/issue/21357 + if !subv.CanSet() { + d.saveError(fmt.Errorf("json: cannot set embedded pointer to unexported struct: %v", subv.Type().Elem())) + // Invalidate subv to ensure d.value(subv) skips over + // the JSON value without assigning it to subv. + subv = reflect.Value{} + destring = false + break + } + subv.Set(reflect.New(subv.Type().Elem())) + } + subv = subv.Elem() + } + if i < len(f.index)-1 { + d.errorContext.FieldStack = append( + d.errorContext.FieldStack, + subv.Type().Field(ind).Name, + ) + } + subv = subv.Field(ind) + } + d.errorContext.Struct = t + d.errorContext.FieldStack = append(d.errorContext.FieldStack, f.name) + } else if d.disallowUnknownFields { + d.saveError(fmt.Errorf("json: unknown field %q", key)) + } + } + + // Read : before value. + if d.opcode == scanSkipSpace { + d.scanWhile(scanSkipSpace) + } + if d.opcode != scanObjectKey { + panic(phasePanicMsg) + } + d.scanWhile(scanSkipSpace) + + if destring { + switch qv := d.valueQuoted().(type) { + case nil: + if err := d.literalStore(nullLiteral, subv, false); err != nil { + return err + } + case string: + if err := d.literalStore([]byte(qv), subv, true); err != nil { + return err + } + default: + d.saveError(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal unquoted value into %v", subv.Type())) + } + } else { + if err := d.value(subv); err != nil { + return err + } + } + + // Write value back to map; + // if using struct, subv points into struct already. + if v.Kind() == reflect.Map { + kt := t.Key() + var kv reflect.Value + if reflect.PointerTo(kt).Implements(textUnmarshalerType) { + kv = reflect.New(kt) + if err := d.literalStore(item, kv, true); err != nil { + return err + } + kv = kv.Elem() + } else { + switch kt.Kind() { + case reflect.String: + kv = reflect.New(kt).Elem() + kv.SetString(string(key)) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + s := string(key) + n, err := strconv.ParseInt(s, 10, 64) + // SHIM(reflect): reflect.Type.OverflowInt(int64) bool + okt := shims.OverflowableType{Type: kt} + if err != nil || okt.OverflowInt(n) { + d.saveError(&UnmarshalTypeError{Value: "number " + s, Type: kt, Offset: int64(start + 1)}) + break + } + kv = reflect.New(kt).Elem() + kv.SetInt(n) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + s := string(key) + n, err := strconv.ParseUint(s, 10, 64) + // SHIM(reflect): reflect.Type.OverflowUint(uint64) bool + okt := shims.OverflowableType{Type: kt} + if err != nil || okt.OverflowUint(n) { + d.saveError(&UnmarshalTypeError{Value: "number " + s, Type: kt, Offset: int64(start + 1)}) + break + } + kv = reflect.New(kt).Elem() + kv.SetUint(n) + default: + panic("json: Unexpected key type") // should never occur + } + } + if kv.IsValid() { + v.SetMapIndex(kv, subv) + } + } + + // Next token must be , or }. + if d.opcode == scanSkipSpace { + d.scanWhile(scanSkipSpace) + } + if d.errorContext != nil { + // Reset errorContext to its original state. + // Keep the same underlying array for FieldStack, to reuse the + // space and avoid unnecessary allocs. + d.errorContext.FieldStack = d.errorContext.FieldStack[:len(origErrorContext.FieldStack)] + d.errorContext.Struct = origErrorContext.Struct + } + if d.opcode == scanEndObject { + break + } + if d.opcode != scanObjectValue { + panic(phasePanicMsg) + } + } + return nil +} + +// convertNumber converts the number literal s to a float64 or a Number +// depending on the setting of d.useNumber. +func (d *decodeState) convertNumber(s string) (any, error) { + if d.useNumber { + return Number(s), nil + } + f, err := strconv.ParseFloat(s, 64) + if err != nil { + // SHIM(reflect): reflect.TypeFor[T]() reflect.Type + return nil, &UnmarshalTypeError{Value: "number " + s, Type: shims.TypeFor[float64](), Offset: int64(d.off)} + } + return f, nil +} + +// SHIM(reflect): TypeFor[T]() reflect.Type +var numberType = shims.TypeFor[Number]() + +// literalStore decodes a literal stored in item into v. +// +// fromQuoted indicates whether this literal came from unwrapping a +// string from the ",string" struct tag option. this is used only to +// produce more helpful error messages. +func (d *decodeState) literalStore(item []byte, v reflect.Value, fromQuoted bool) error { + // Check for unmarshaler. + if len(item) == 0 { + // Empty string given. + d.saveError(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type())) + return nil + } + isNull := item[0] == 'n' // null + u, ut, pv := indirect(v, isNull) + if u != nil { + return u.UnmarshalJSON(item) + } + if ut != nil { + if item[0] != '"' { + if fromQuoted { + d.saveError(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type())) + return nil + } + val := "number" + switch item[0] { + case 'n': + val = "null" + case 't', 'f': + val = "bool" + } + d.saveError(&UnmarshalTypeError{Value: val, Type: v.Type(), Offset: int64(d.readIndex())}) + return nil + } + s, ok := unquoteBytes(item) + if !ok { + if fromQuoted { + return fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type()) + } + panic(phasePanicMsg) + } + return ut.UnmarshalText(s) + } + + v = pv + + switch c := item[0]; c { + case 'n': // null + // The main parser checks that only true and false can reach here, + // but if this was a quoted string input, it could be anything. + if fromQuoted && string(item) != "null" { + d.saveError(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type())) + break + } + switch v.Kind() { + case reflect.Interface, reflect.Pointer, reflect.Map, reflect.Slice: + v.SetZero() + // otherwise, ignore null for primitives/string + } + case 't', 'f': // true, false + value := item[0] == 't' + // The main parser checks that only true and false can reach here, + // but if this was a quoted string input, it could be anything. + if fromQuoted && string(item) != "true" && string(item) != "false" { + d.saveError(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type())) + break + } + switch v.Kind() { + default: + if fromQuoted { + d.saveError(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type())) + } else { + d.saveError(&UnmarshalTypeError{Value: "bool", Type: v.Type(), Offset: int64(d.readIndex())}) + } + case reflect.Bool: + v.SetBool(value) + case reflect.Interface: + if v.NumMethod() == 0 { + v.Set(reflect.ValueOf(value)) + } else { + d.saveError(&UnmarshalTypeError{Value: "bool", Type: v.Type(), Offset: int64(d.readIndex())}) + } + } + + case '"': // string + s, ok := unquoteBytes(item) + if !ok { + if fromQuoted { + return fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type()) + } + panic(phasePanicMsg) + } + switch v.Kind() { + default: + d.saveError(&UnmarshalTypeError{Value: "string", Type: v.Type(), Offset: int64(d.readIndex())}) + case reflect.Slice: + if v.Type().Elem().Kind() != reflect.Uint8 { + d.saveError(&UnmarshalTypeError{Value: "string", Type: v.Type(), Offset: int64(d.readIndex())}) + break + } + b := make([]byte, base64.StdEncoding.DecodedLen(len(s))) + n, err := base64.StdEncoding.Decode(b, s) + if err != nil { + d.saveError(err) + break + } + v.SetBytes(b[:n]) + case reflect.String: + t := string(s) + if v.Type() == numberType && !isValidNumber(t) { + return fmt.Errorf("json: invalid number literal, trying to unmarshal %q into Number", item) + } + v.SetString(t) + case reflect.Interface: + if v.NumMethod() == 0 { + v.Set(reflect.ValueOf(string(s))) + } else { + d.saveError(&UnmarshalTypeError{Value: "string", Type: v.Type(), Offset: int64(d.readIndex())}) + } + } + + default: // number + if c != '-' && (c < '0' || c > '9') { + if fromQuoted { + return fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type()) + } + panic(phasePanicMsg) + } + switch v.Kind() { + default: + if v.Kind() == reflect.String && v.Type() == numberType { + // s must be a valid number, because it's + // already been tokenized. + v.SetString(string(item)) + break + } + if fromQuoted { + return fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type()) + } + d.saveError(&UnmarshalTypeError{Value: "number", Type: v.Type(), Offset: int64(d.readIndex())}) + case reflect.Interface: + n, err := d.convertNumber(string(item)) + if err != nil { + d.saveError(err) + break + } + if v.NumMethod() != 0 { + d.saveError(&UnmarshalTypeError{Value: "number", Type: v.Type(), Offset: int64(d.readIndex())}) + break + } + v.Set(reflect.ValueOf(n)) + + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + n, err := strconv.ParseInt(string(item), 10, 64) + if err != nil || v.OverflowInt(n) { + d.saveError(&UnmarshalTypeError{Value: "number " + string(item), Type: v.Type(), Offset: int64(d.readIndex())}) + break + } + v.SetInt(n) + + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + n, err := strconv.ParseUint(string(item), 10, 64) + if err != nil || v.OverflowUint(n) { + d.saveError(&UnmarshalTypeError{Value: "number " + string(item), Type: v.Type(), Offset: int64(d.readIndex())}) + break + } + v.SetUint(n) + + case reflect.Float32, reflect.Float64: + n, err := strconv.ParseFloat(string(item), v.Type().Bits()) + if err != nil || v.OverflowFloat(n) { + d.saveError(&UnmarshalTypeError{Value: "number " + string(item), Type: v.Type(), Offset: int64(d.readIndex())}) + break + } + v.SetFloat(n) + } + } + return nil +} + +// The xxxInterface routines build up a value to be stored +// in an empty interface. They are not strictly necessary, +// but they avoid the weight of reflection in this common case. + +// valueInterface is like value but returns any. +func (d *decodeState) valueInterface() (val any) { + switch d.opcode { + default: + panic(phasePanicMsg) + case scanBeginArray: + val = d.arrayInterface() + d.scanNext() + case scanBeginObject: + val = d.objectInterface() + d.scanNext() + case scanBeginLiteral: + val = d.literalInterface() + } + return +} + +// arrayInterface is like array but returns []any. +func (d *decodeState) arrayInterface() []any { + var v = make([]any, 0) + for { + // Look ahead for ] - can only happen on first iteration. + d.scanWhile(scanSkipSpace) + if d.opcode == scanEndArray { + break + } + + v = append(v, d.valueInterface()) + + // Next token must be , or ]. + if d.opcode == scanSkipSpace { + d.scanWhile(scanSkipSpace) + } + if d.opcode == scanEndArray { + break + } + if d.opcode != scanArrayValue { + panic(phasePanicMsg) + } + } + return v +} + +// objectInterface is like object but returns map[string]any. +func (d *decodeState) objectInterface() map[string]any { + m := make(map[string]any) + for { + // Read opening " of string key or closing }. + d.scanWhile(scanSkipSpace) + if d.opcode == scanEndObject { + // closing } - can only happen on first iteration. + break + } + if d.opcode != scanBeginLiteral { + panic(phasePanicMsg) + } + + // Read string key. + start := d.readIndex() + d.rescanLiteral() + item := d.data[start:d.readIndex()] + key, ok := unquote(item) + if !ok { + panic(phasePanicMsg) + } + + // Read : before value. + if d.opcode == scanSkipSpace { + d.scanWhile(scanSkipSpace) + } + if d.opcode != scanObjectKey { + panic(phasePanicMsg) + } + d.scanWhile(scanSkipSpace) + + // Read value. + m[key] = d.valueInterface() + + // Next token must be , or }. + if d.opcode == scanSkipSpace { + d.scanWhile(scanSkipSpace) + } + if d.opcode == scanEndObject { + break + } + if d.opcode != scanObjectValue { + panic(phasePanicMsg) + } + } + return m +} + +// literalInterface consumes and returns a literal from d.data[d.off-1:] and +// it reads the following byte ahead. The first byte of the literal has been +// read already (that's how the caller knows it's a literal). +func (d *decodeState) literalInterface() any { + // All bytes inside literal return scanContinue op code. + start := d.readIndex() + d.rescanLiteral() + + item := d.data[start:d.readIndex()] + + switch c := item[0]; c { + case 'n': // null + return nil + + case 't', 'f': // true, false + return c == 't' + + case '"': // string + s, ok := unquote(item) + if !ok { + panic(phasePanicMsg) + } + return s + + default: // number + if c != '-' && (c < '0' || c > '9') { + panic(phasePanicMsg) + } + n, err := d.convertNumber(string(item)) + if err != nil { + d.saveError(err) + } + return n + } +} + +// getu4 decodes \uXXXX from the beginning of s, returning the hex value, +// or it returns -1. +func getu4(s []byte) rune { + if len(s) < 6 || s[0] != '\\' || s[1] != 'u' { + return -1 + } + var r rune + for _, c := range s[2:6] { + switch { + case '0' <= c && c <= '9': + c = c - '0' + case 'a' <= c && c <= 'f': + c = c - 'a' + 10 + case 'A' <= c && c <= 'F': + c = c - 'A' + 10 + default: + return -1 + } + r = r*16 + rune(c) + } + return r +} + +// unquote converts a quoted JSON string literal s into an actual string t. +// The rules are different than for Go, so cannot use strconv.Unquote. +func unquote(s []byte) (t string, ok bool) { + s, ok = unquoteBytes(s) + t = string(s) + return +} + +// unquoteBytes should be an internal detail, +// but widely used packages access it using linkname. +// Notable members of the hall of shame include: +// - github.com/bytedance/sonic +// +// Do not remove or change the type signature. +// See go.dev/issue/67401. +// +//go:linkname unquoteBytes +func unquoteBytes(s []byte) (t []byte, ok bool) { + if len(s) < 2 || s[0] != '"' || s[len(s)-1] != '"' { + return + } + s = s[1 : len(s)-1] + + // Check for unusual characters. If there are none, + // then no unquoting is needed, so return a slice of the + // original bytes. + r := 0 + for r < len(s) { + c := s[r] + if c == '\\' || c == '"' || c < ' ' { + break + } + if c < utf8.RuneSelf { + r++ + continue + } + rr, size := utf8.DecodeRune(s[r:]) + if rr == utf8.RuneError && size == 1 { + break + } + r += size + } + if r == len(s) { + return s, true + } + + b := make([]byte, len(s)+2*utf8.UTFMax) + w := copy(b, s[0:r]) + for r < len(s) { + // Out of room? Can only happen if s is full of + // malformed UTF-8 and we're replacing each + // byte with RuneError. + if w >= len(b)-2*utf8.UTFMax { + nb := make([]byte, (len(b)+utf8.UTFMax)*2) + copy(nb, b[0:w]) + b = nb + } + switch c := s[r]; { + case c == '\\': + r++ + if r >= len(s) { + return + } + switch s[r] { + default: + return + case '"', '\\', '/', '\'': + b[w] = s[r] + r++ + w++ + case 'b': + b[w] = '\b' + r++ + w++ + case 'f': + b[w] = '\f' + r++ + w++ + case 'n': + b[w] = '\n' + r++ + w++ + case 'r': + b[w] = '\r' + r++ + w++ + case 't': + b[w] = '\t' + r++ + w++ + case 'u': + r-- + rr := getu4(s[r:]) + if rr < 0 { + return + } + r += 6 + if utf16.IsSurrogate(rr) { + rr1 := getu4(s[r:]) + if dec := utf16.DecodeRune(rr, rr1); dec != unicode.ReplacementChar { + // A valid pair; consume. + r += 6 + w += utf8.EncodeRune(b[w:], dec) + break + } + // Invalid surrogate; fall back to replacement rune. + rr = unicode.ReplacementChar + } + w += utf8.EncodeRune(b[w:], rr) + } + + // Quote, control characters are invalid. + case c == '"', c < ' ': + return + + // ASCII + case c < utf8.RuneSelf: + b[w] = c + r++ + w++ + + // Coerce to well-formed UTF-8. + default: + rr, size := utf8.DecodeRune(s[r:]) + r += size + w += utf8.EncodeRune(b[w:], rr) + } + } + return b[0:w], true +} diff --git a/internal/encoding/json/encode.go b/internal/encoding/json/encode.go new file mode 100644 index 0000000..546893d --- /dev/null +++ b/internal/encoding/json/encode.go @@ -0,0 +1,1395 @@ +// Vendored from Go 1.24.0-pre-release +// To find alterations, check package shims, and comments beginning in SHIM(). +// +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package json implements encoding and decoding of JSON as defined in +// RFC 7159. The mapping between JSON and Go values is described +// in the documentation for the Marshal and Unmarshal functions. +// +// See "JSON and Go" for an introduction to this package: +// https://golang.org/doc/articles/json_and_go.html +package json + +import ( + "bytes" + "cmp" + "encoding" + "encoding/base64" + "fmt" + "scalar-api/internal/encoding/json/sentinel" + "scalar-api/internal/encoding/json/shims" + "math" + "reflect" + "slices" + "strconv" + "strings" + "sync" + "unicode" + "unicode/utf8" + _ "unsafe" // for linkname +) + +// Marshal returns the JSON encoding of v. +// +// Marshal traverses the value v recursively. +// If an encountered value implements [Marshaler] +// and is not a nil pointer, Marshal calls [Marshaler.MarshalJSON] +// to produce JSON. If no [Marshaler.MarshalJSON] method is present but the +// value implements [encoding.TextMarshaler] instead, Marshal calls +// [encoding.TextMarshaler.MarshalText] and encodes the result as a JSON string. +// The nil pointer exception is not strictly necessary +// but mimics a similar, necessary exception in the behavior of +// [Unmarshaler.UnmarshalJSON]. +// +// Otherwise, Marshal uses the following type-dependent default encodings: +// +// Boolean values encode as JSON booleans. +// +// Floating point, integer, and [Number] values encode as JSON numbers. +// NaN and +/-Inf values will return an [UnsupportedValueError]. +// +// String values encode as JSON strings coerced to valid UTF-8, +// replacing invalid bytes with the Unicode replacement rune. +// So that the JSON will be safe to embed inside HTML