Mirror of https://github.com/onyx-dot-app/onyx.git (synced 2026-03-11 10:42:39 +00:00)

Compare commits: sharepoint...jamison/Co (424 commits)
.cursor/skills/onyx-cli/SKILL.md (new file, 186 lines)
@@ -0,0 +1,186 @@
---
name: onyx-cli
description: Query the Onyx knowledge base using the onyx-cli command. Use when the user wants to search company documents, ask questions about internal knowledge, query connected data sources, or look up information stored in Onyx.
---

# Onyx CLI — Agent Tool

Onyx is an enterprise search and Gen-AI platform that connects to company documents, apps, and people. The `onyx-cli` CLI provides non-interactive commands to query the Onyx knowledge base and list available agents.

## Prerequisites

### 1. Check if installed

```bash
which onyx-cli
```

### 2. Install (if needed)

**Primary — pip:**

```bash
pip install onyx-cli
```

**From source (Go):**

```bash
cd cli && go build -o onyx-cli . && sudo mv onyx-cli /usr/local/bin/
```

### 3. Check if configured

```bash
onyx-cli validate-config
```

This checks that the config file exists and an API key is present, and tests the server connection via `/api/me`. Exit code 0 on success, non-zero with a descriptive error on failure.
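Because `validate-config` signals readiness purely through its exit code, a pre-flight gate can be a one-liner; a minimal sketch:

```bash
# Abort early if the CLI is not set up; validate-config exits non-zero on failure.
if ! onyx-cli validate-config; then
  echo "onyx-cli is not configured; run 'onyx-cli configure' or set ONYX_SERVER_URL/ONYX_API_KEY" >&2
  exit 1
fi
onyx-cli ask "What is our company's PTO policy?"
```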
If unconfigured, you have two options:

**Option A — Interactive setup (requires user input):**

```bash
onyx-cli configure
```

This prompts for the Onyx server URL and API key, tests the connection, and saves the config.

**Option B — Environment variables (non-interactive, preferred for agents):**

```bash
export ONYX_SERVER_URL="https://your-onyx-server.com" # default: https://cloud.onyx.app
export ONYX_API_KEY="your-api-key"
```

Environment variables override the config file. If these are set, no config file is needed.

| Variable | Required | Description |
|----------|----------|-------------|
| `ONYX_SERVER_URL` | No | Onyx server base URL (default: `https://cloud.onyx.app`) |
| `ONYX_API_KEY` | Yes | API key for authentication |
| `ONYX_PERSONA_ID` | No | Default agent/persona ID |

If neither the config file nor environment variables are set, tell the user that `onyx-cli` needs to be configured and ask them to either:

- Run `onyx-cli configure` interactively, or
- Set `ONYX_SERVER_URL` and `ONYX_API_KEY` environment variables

## Commands

### Validate configuration

```bash
onyx-cli validate-config
```

Checks that the config file exists and an API key is present, and tests the server connection. Use this before `ask` or `agents` to confirm the CLI is properly set up.

### List available agents

```bash
onyx-cli agents
```

Prints a table of agent IDs, names, and descriptions. Use `--json` for structured output (a parsing sketch follows):

```bash
onyx-cli agents --json
```

Use agent IDs with `ask --agent-id` to query a specific agent.
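A hedged parsing sketch for the structured output; the `id` and `name` field names here are assumptions (the exact JSON shape is not documented above), so adjust to what the command actually emits:

```bash
# Hypothetical field names: adjust `.id` / `.name` to the real schema.
onyx-cli agents --json | jq -r '.[] | "\(.id)\t\(.name)"'
```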
### Basic query (plain text output)

```bash
onyx-cli ask "What is our company's PTO policy?"
```

Streams the answer as plain text to stdout. Exit code 0 on success, non-zero on error.

### JSON output (structured events)

```bash
onyx-cli ask --json "What authentication methods do we support?"
```

Outputs JSON-encoded parsed stream events (one object per line). Key event objects include message deltas, stop, errors, search-start, and citation payloads.

Each line is a JSON object with this envelope:

```json
{"type": "<event_type>", "event": { ... }}
```

| Event Type | Description |
|------------|-------------|
| `message_delta` | Content token — concatenate all `content` fields for the full answer |
| `stop` | Stream complete |
| `error` | Error with `error` message field |
| `search_tool_start` | Onyx started searching documents |
| `citation_info` | Source citation — see shape below |

`citation_info` event shape:

```json
{
  "type": "citation_info",
  "event": {
    "citation_number": 1,
    "document_id": "abc123def456",
    "placement": {"turn_index": 0, "tab_index": 0, "sub_turn_index": null}
  }
}
```

`placement` is metadata about where in the conversation the citation appeared and can be ignored for most use cases.
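As a minimal consumption sketch (assuming `jq` is available and that the delta text lives at `.event.content`, per the envelope above; verify against real output):

```bash
# Capture the NDJSON stream once, then extract what we need from it.
out="$(onyx-cli ask --json "What authentication methods do we support?")"

# Reassemble the full answer by concatenating message_delta content fields.
printf '%s\n' "$out" | jq -j 'select(.type == "message_delta") | .event.content // empty'
echo

# List the cited document IDs.
printf '%s\n' "$out" | jq -r 'select(.type == "citation_info") | .event.document_id'
```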
### Specify an agent

```bash
onyx-cli ask --agent-id 5 "Summarize our Q4 roadmap"
```

Uses a specific Onyx agent/persona instead of the default.

### All flags

| Flag | Type | Description |
|------|------|-------------|
| `--agent-id` | int | Agent ID to use (overrides default) |
| `--json` | bool | Output raw NDJSON events instead of plain text |

## Statelessness

Each `onyx-cli ask` call creates an independent chat session. There is no built-in way to chain context across multiple `ask` invocations — every call starts fresh. If you need a multi-turn conversation with memory, use the interactive TUI (`onyx-cli` or `onyx-cli chat`) instead.
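One workaround sketch, since every call starts fresh: an agent can carry context forward manually by embedding an earlier answer in the next prompt (this re-sends the text; it is not real session memory):

```bash
# Manually thread context between two independent ask calls.
prev="$(onyx-cli ask "What is our incident response process?")"
onyx-cli ask "Given this context: ${prev}. Who is the escalation contact for severity-1 incidents?"
```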
## When to Use

Use `onyx-cli ask` when:

- The user asks about company-specific information (policies, docs, processes)
- You need to search internal knowledge bases or connected data sources
- The user references Onyx, asks you to "search Onyx", or wants to query their documents
- You need context from company wikis, Confluence, Google Drive, Slack, or other connected sources

Do NOT use when:

- The question is about general programming knowledge (use your own knowledge)
- The user is asking about code in the current repository (use grep/read tools)
- The user hasn't mentioned Onyx and the question doesn't require internal company data

## Examples

```bash
# Simple question
onyx-cli ask "What are the steps to deploy to production?"

# Get structured output for parsing
onyx-cli ask --json "List all active API integrations"

# Use a specialized agent
onyx-cli ask --agent-id 3 "What were the action items from last week's standup?"

# Pipe the answer into another command
onyx-cli ask "What is the database schema for users?" | head -20
```
.github/actions/build-backend-image/action.yml (vendored, new file, 73 lines)
@@ -0,0 +1,73 @@
name: "Build Backend Image"
description: "Builds and pushes the backend Docker image with cache reuse"
inputs:
  runs-on-ecr-cache:
    description: "ECR cache registry from runs-on/action"
    required: true
  ref-name:
    description: "Git ref name used for cache suffix fallback"
    required: true
  pr-number:
    description: "Optional PR number for cache suffix"
    required: false
    default: ""
  github-sha:
    description: "Commit SHA used for cache keys"
    required: true
  run-id:
    description: "GitHub run ID used in output image tag"
    required: true
  docker-username:
    description: "Docker Hub username"
    required: true
  docker-token:
    description: "Docker Hub token"
    required: true
  docker-no-cache:
    description: "Set to 'true' to disable docker build cache"
    required: false
    default: "false"
runs:
  using: "composite"
  steps:
    - name: Format branch name for cache
      id: format-branch
      shell: bash
      env:
        PR_NUMBER: ${{ inputs.pr-number }}
        REF_NAME: ${{ inputs.ref-name }}
      run: |
        if [ -n "${PR_NUMBER}" ]; then
          CACHE_SUFFIX="${PR_NUMBER}"
        else
          # shellcheck disable=SC2001
          CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g')
        fi
        echo "cache-suffix=${CACHE_SUFFIX}" >> "$GITHUB_OUTPUT"

    - name: Set up Docker Buildx
      uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3

    - name: Login to Docker Hub
      uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
      with:
        username: ${{ inputs.docker-username }}
        password: ${{ inputs.docker-token }}

    - name: Build and push Backend Docker image
      uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
      with:
        context: ./backend
        file: ./backend/Dockerfile
        push: true
        tags: ${{ inputs.runs-on-ecr-cache }}:nightly-llm-it-backend-${{ inputs.run-id }}
        cache-from: |
          type=registry,ref=${{ inputs.runs-on-ecr-cache }}:backend-cache-${{ inputs.github-sha }}
          type=registry,ref=${{ inputs.runs-on-ecr-cache }}:backend-cache-${{ steps.format-branch.outputs.cache-suffix }}
          type=registry,ref=${{ inputs.runs-on-ecr-cache }}:backend-cache
          type=registry,ref=onyxdotapp/onyx-backend:latest
        cache-to: |
          type=registry,ref=${{ inputs.runs-on-ecr-cache }}:backend-cache-${{ inputs.github-sha }},mode=max
          type=registry,ref=${{ inputs.runs-on-ecr-cache }}:backend-cache-${{ steps.format-branch.outputs.cache-suffix }},mode=max
          type=registry,ref=${{ inputs.runs-on-ecr-cache }}:backend-cache,mode=max
        no-cache: ${{ inputs.docker-no-cache == 'true' }}
.github/actions/build-integration-image/action.yml (vendored, new file, 76 lines)
@@ -0,0 +1,76 @@
name: "Build Integration Image"
description: "Builds and pushes the integration test image with docker bake"
inputs:
  runs-on-ecr-cache:
    description: "ECR cache registry from runs-on/action"
    required: true
  ref-name:
    description: "Git ref name used for cache suffix fallback"
    required: true
  pr-number:
    description: "Optional PR number for cache suffix"
    required: false
    default: ""
  github-sha:
    description: "Commit SHA used for cache keys"
    required: true
  run-id:
    description: "GitHub run ID used in output image tag"
    required: true
  docker-username:
    description: "Docker Hub username"
    required: true
  docker-token:
    description: "Docker Hub token"
    required: true
runs:
  using: "composite"
  steps:
    - name: Set up Docker Buildx
      uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3

    - name: Login to Docker Hub
      uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
      with:
        username: ${{ inputs.docker-username }}
        password: ${{ inputs.docker-token }}

    - name: Format branch name for cache
      id: format-branch
      shell: bash
      env:
        PR_NUMBER: ${{ inputs.pr-number }}
        REF_NAME: ${{ inputs.ref-name }}
      run: |
        if [ -n "${PR_NUMBER}" ]; then
          CACHE_SUFFIX="${PR_NUMBER}"
        else
          # shellcheck disable=SC2001
          CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g')
        fi
        echo "cache-suffix=${CACHE_SUFFIX}" >> "$GITHUB_OUTPUT"

    - name: Build and push integration test image with Docker Bake
      shell: bash
      env:
        RUNS_ON_ECR_CACHE: ${{ inputs.runs-on-ecr-cache }}
        INTEGRATION_REPOSITORY: ${{ inputs.runs-on-ecr-cache }}
        TAG: nightly-llm-it-${{ inputs.run-id }}
        CACHE_SUFFIX: ${{ steps.format-branch.outputs.cache-suffix }}
        HEAD_SHA: ${{ inputs.github-sha }}
      run: |
        docker buildx bake --push \
          --set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${HEAD_SHA} \
          --set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${CACHE_SUFFIX} \
          --set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache \
          --set backend.cache-from=type=registry,ref=onyxdotapp/onyx-backend:latest \
          --set backend.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${HEAD_SHA},mode=max \
          --set backend.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${CACHE_SUFFIX},mode=max \
          --set backend.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache,mode=max \
          --set integration.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${HEAD_SHA} \
          --set integration.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${CACHE_SUFFIX} \
          --set integration.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache \
          --set integration.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${HEAD_SHA},mode=max \
          --set integration.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${CACHE_SUFFIX},mode=max \
          --set integration.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache,mode=max \
          integration
.github/actions/build-model-server-image/action.yml (vendored, new file, 68 lines)
@@ -0,0 +1,68 @@
name: "Build Model Server Image"
description: "Builds and pushes the model server Docker image with cache reuse"
inputs:
  runs-on-ecr-cache:
    description: "ECR cache registry from runs-on/action"
    required: true
  ref-name:
    description: "Git ref name used for cache suffix fallback"
    required: true
  pr-number:
    description: "Optional PR number for cache suffix"
    required: false
    default: ""
  github-sha:
    description: "Commit SHA used for cache keys"
    required: true
  run-id:
    description: "GitHub run ID used in output image tag"
    required: true
  docker-username:
    description: "Docker Hub username"
    required: true
  docker-token:
    description: "Docker Hub token"
    required: true
runs:
  using: "composite"
  steps:
    - name: Format branch name for cache
      id: format-branch
      shell: bash
      env:
        PR_NUMBER: ${{ inputs.pr-number }}
        REF_NAME: ${{ inputs.ref-name }}
      run: |
        if [ -n "${PR_NUMBER}" ]; then
          CACHE_SUFFIX="${PR_NUMBER}"
        else
          # shellcheck disable=SC2001
          CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g')
        fi
        echo "cache-suffix=${CACHE_SUFFIX}" >> "$GITHUB_OUTPUT"

    - name: Set up Docker Buildx
      uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3

    - name: Login to Docker Hub
      uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
      with:
        username: ${{ inputs.docker-username }}
        password: ${{ inputs.docker-token }}

    - name: Build and push Model Server Docker image
      uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
      with:
        context: ./backend
        file: ./backend/Dockerfile.model_server
        push: true
        tags: ${{ inputs.runs-on-ecr-cache }}:nightly-llm-it-model-server-${{ inputs.run-id }}
        cache-from: |
          type=registry,ref=${{ inputs.runs-on-ecr-cache }}:model-server-cache-${{ inputs.github-sha }}
          type=registry,ref=${{ inputs.runs-on-ecr-cache }}:model-server-cache-${{ steps.format-branch.outputs.cache-suffix }}
          type=registry,ref=${{ inputs.runs-on-ecr-cache }}:model-server-cache
          type=registry,ref=onyxdotapp/onyx-model-server:latest
        cache-to: |
          type=registry,ref=${{ inputs.runs-on-ecr-cache }}:model-server-cache-${{ inputs.github-sha }},mode=max
          type=registry,ref=${{ inputs.runs-on-ecr-cache }}:model-server-cache-${{ steps.format-branch.outputs.cache-suffix }},mode=max
          type=registry,ref=${{ inputs.runs-on-ecr-cache }}:model-server-cache,mode=max
.github/actions/run-nightly-provider-chat-test/action.yml (vendored, new file, 130 lines)
@@ -0,0 +1,130 @@
name: "Run Nightly Provider Chat Test"
description: "Starts required compose services and runs nightly provider integration test"
inputs:
  provider:
    description: "Provider slug for NIGHTLY_LLM_PROVIDER"
    required: true
  models:
    description: "Comma-separated model list for NIGHTLY_LLM_MODELS"
    required: true
  provider-api-key:
    description: "API key for NIGHTLY_LLM_API_KEY"
    required: false
    default: ""
  strict:
    description: "String true/false for NIGHTLY_LLM_STRICT"
    required: true
  api-base:
    description: "Optional NIGHTLY_LLM_API_BASE"
    required: false
    default: ""
  api-version:
    description: "Optional NIGHTLY_LLM_API_VERSION"
    required: false
    default: ""
  deployment-name:
    description: "Optional NIGHTLY_LLM_DEPLOYMENT_NAME"
    required: false
    default: ""
  custom-config-json:
    description: "Optional NIGHTLY_LLM_CUSTOM_CONFIG_JSON"
    required: false
    default: ""
  runs-on-ecr-cache:
    description: "ECR cache registry from runs-on/action"
    required: true
  run-id:
    description: "GitHub run ID used in image tags"
    required: true
  docker-username:
    description: "Docker Hub username"
    required: true
  docker-token:
    description: "Docker Hub token"
    required: true
runs:
  using: "composite"
  steps:
    - name: Login to Docker Hub
      uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
      with:
        username: ${{ inputs.docker-username }}
        password: ${{ inputs.docker-token }}

    - name: Create .env file for Docker Compose
      shell: bash
      env:
        ECR_CACHE: ${{ inputs.runs-on-ecr-cache }}
        RUN_ID: ${{ inputs.run-id }}
      run: |
        cat <<EOF2 > deployment/docker_compose/.env
        COMPOSE_PROFILES=s3-filestore
        ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
        LICENSE_ENFORCEMENT_ENABLED=false
        AUTH_TYPE=basic
        POSTGRES_POOL_PRE_PING=true
        POSTGRES_USE_NULL_POOL=true
        REQUIRE_EMAIL_VERIFICATION=false
        DISABLE_TELEMETRY=true
        INTEGRATION_TESTS_MODE=true
        AUTO_LLM_UPDATE_INTERVAL_SECONDS=10
        AWS_REGION_NAME=us-west-2
        ONYX_BACKEND_IMAGE=${ECR_CACHE}:nightly-llm-it-backend-${RUN_ID}
        ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:nightly-llm-it-model-server-${RUN_ID}
        EOF2

    - name: Start Docker containers
      shell: bash
      run: |
        cd deployment/docker_compose
        docker compose -f docker-compose.yml -f docker-compose.dev.yml up -d --wait \
          relational_db \
          index \
          cache \
          minio \
          api_server \
          inference_model_server

    - name: Run nightly provider integration test
      uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
      env:
        MODELS: ${{ inputs.models }}
        NIGHTLY_LLM_PROVIDER: ${{ inputs.provider }}
        NIGHTLY_LLM_API_KEY: ${{ inputs.provider-api-key }}
        NIGHTLY_LLM_API_BASE: ${{ inputs.api-base }}
        NIGHTLY_LLM_API_VERSION: ${{ inputs.api-version }}
        NIGHTLY_LLM_DEPLOYMENT_NAME: ${{ inputs.deployment-name }}
        NIGHTLY_LLM_CUSTOM_CONFIG_JSON: ${{ inputs.custom-config-json }}
        NIGHTLY_LLM_STRICT: ${{ inputs.strict }}
        RUNS_ON_ECR_CACHE: ${{ inputs.runs-on-ecr-cache }}
        RUN_ID: ${{ inputs.run-id }}
      with:
        timeout_minutes: 20
        max_attempts: 2
        retry_wait_seconds: 10
        command: |
          docker run --rm --network onyx_default \
            --name test-runner \
            -e POSTGRES_HOST=relational_db \
            -e POSTGRES_USER=postgres \
            -e POSTGRES_PASSWORD=password \
            -e POSTGRES_DB=postgres \
            -e DB_READONLY_USER=db_readonly_user \
            -e DB_READONLY_PASSWORD=password \
            -e POSTGRES_POOL_PRE_PING=true \
            -e POSTGRES_USE_NULL_POOL=true \
            -e VESPA_HOST=index \
            -e REDIS_HOST=cache \
            -e API_SERVER_HOST=api_server \
            -e TEST_WEB_HOSTNAME=test-runner \
            -e AWS_REGION_NAME=us-west-2 \
            -e NIGHTLY_LLM_PROVIDER="${NIGHTLY_LLM_PROVIDER}" \
            -e NIGHTLY_LLM_MODELS="${MODELS}" \
            -e NIGHTLY_LLM_API_KEY="${NIGHTLY_LLM_API_KEY}" \
            -e NIGHTLY_LLM_API_BASE="${NIGHTLY_LLM_API_BASE}" \
            -e NIGHTLY_LLM_API_VERSION="${NIGHTLY_LLM_API_VERSION}" \
            -e NIGHTLY_LLM_DEPLOYMENT_NAME="${NIGHTLY_LLM_DEPLOYMENT_NAME}" \
            -e NIGHTLY_LLM_CUSTOM_CONFIG_JSON="${NIGHTLY_LLM_CUSTOM_CONFIG_JSON}" \
            -e NIGHTLY_LLM_STRICT="${NIGHTLY_LLM_STRICT}" \
            ${RUNS_ON_ECR_CACHE}:nightly-llm-it-${RUN_ID} \
            /app/tests/integration/tests/llm_workflows/test_nightly_provider_chat_workflow.py
.github/pull_request_template.md (vendored, 2 changes)
@@ -8,5 +8,5 @@

## Additional Options

- [ ] [Required] I have considered whether this PR needs to be cherry-picked to the latest beta branch.
- [ ] [Optional] Please cherry-pick this PR to the latest release version.
- [ ] [Optional] Override Linear Check
.github/workflows/deployment.yml (vendored, 124 changes)
@@ -151,7 +151,7 @@ jobs:
          fetch-depth: 0

      - name: Setup uv
        uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
        uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7
        with:
          version: "0.9.9"
          # NOTE: This isn't caching much and zizmor suggests this could be poisoned, so disable.

@@ -182,9 +182,53 @@
          title: "🚨 Version Tag Check Failed"
          ref-name: ${{ github.ref_name }}

  build-desktop:
  # Create GitHub release first, before desktop builds start.
  # This ensures all desktop matrix jobs upload to the same release instead of
  # racing to create duplicate releases.
  create-release:
    needs: determine-builds
    if: needs.determine-builds.outputs.build-desktop == 'true'
    runs-on: ubuntu-slim
    timeout-minutes: 10
    permissions:
      contents: write
    outputs:
      release-id: ${{ steps.create-release.outputs.id }}
    steps:
      - name: Checkout
        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
        with:
          persist-credentials: false

      - name: Determine release tag
        id: release-tag
        env:
          IS_TEST_RUN: ${{ needs.determine-builds.outputs.is-test-run }}
          SHORT_SHA: ${{ needs.determine-builds.outputs.short-sha }}
        run: |
          if [ "${IS_TEST_RUN}" == "true" ]; then
            echo "tag=v0.0.0-dev+${SHORT_SHA}" >> "$GITHUB_OUTPUT"
          else
            echo "tag=${GITHUB_REF_NAME}" >> "$GITHUB_OUTPUT"
          fi

      - name: Create GitHub Release
        id: create-release
        uses: softprops/action-gh-release@da05d552573ad5aba039eaac05058a918a7bf631 # ratchet:softprops/action-gh-release@v2
        with:
          tag_name: ${{ steps.release-tag.outputs.tag }}
          name: ${{ steps.release-tag.outputs.tag }}
          body: "See the assets to download this version and install."
          draft: true
          prerelease: false
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

  build-desktop:
    needs:
      - determine-builds
      - create-release
    if: needs.determine-builds.outputs.build-desktop == 'true'
    permissions:
      id-token: write
      contents: write

@@ -208,12 +252,12 @@
    steps:
      - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6.0.2
        with:
          # NOTE: persist-credentials is needed for tauri-action to create GitHub releases.
          # NOTE: persist-credentials is needed for tauri-action to upload assets to GitHub releases.
          persist-credentials: true # zizmor: ignore[artipacked]

      - name: Configure AWS credentials
        if: startsWith(matrix.platform, 'macos-')
        uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
        with:
          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
          aws-region: us-east-2

@@ -353,11 +397,9 @@
          APPLE_SIGNING_IDENTITY: ${{ env.CERT_ID }}
          APPLE_TEAM_ID: ${{ env.APPLE_TEAM_ID }}
        with:
          tagName: ${{ needs.determine-builds.outputs.is-test-run != 'true' && 'v__VERSION__' || format('v0.0.0-dev+{0}', needs.determine-builds.outputs.short-sha) }}
          releaseName: ${{ needs.determine-builds.outputs.is-test-run != 'true' && 'v__VERSION__' || format('v0.0.0-dev+{0}', needs.determine-builds.outputs.short-sha) }}
          releaseBody: "See the assets to download this version and install."
          releaseDraft: true
          prerelease: false
          # Use the release created by the create-release job to avoid race conditions
          # when multiple matrix jobs try to create/update the same release simultaneously
          releaseId: ${{ needs.create-release.outputs.release-id }}
          assetNamePattern: "[name]_[arch][ext]"
          args: ${{ matrix.args }}

@@ -384,7 +426,7 @@
          persist-credentials: false

      - name: Configure AWS credentials
        uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
        with:
          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
          aws-region: us-east-2

@@ -426,8 +468,9 @@
            ONYX_VERSION=${{ github.ref_name }}
            NODE_OPTIONS=--max-old-space-size=8192
          cache-from: |
            type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-amd64
            type=registry,ref=${{ env.REGISTRY_IMAGE }}:edge
            type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
          cache-to: |
            type=inline
            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-amd64,mode=max

@@ -457,7 +500,7 @@
          persist-credentials: false

      - name: Configure AWS credentials
        uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
        with:
          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
          aws-region: us-east-2

@@ -499,8 +542,9 @@
            ONYX_VERSION=${{ github.ref_name }}
            NODE_OPTIONS=--max-old-space-size=8192
          cache-from: |
            type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-arm64
            type=registry,ref=${{ env.REGISTRY_IMAGE }}:edge
            type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
          cache-to: |
            type=inline
            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-arm64,mode=max

@@ -525,7 +569,7 @@
      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

      - name: Configure AWS credentials
        uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
        with:
          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
          aws-region: us-east-2

@@ -595,7 +639,7 @@
          persist-credentials: false

      - name: Configure AWS credentials
        uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
        with:
          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
          aws-region: us-east-2

@@ -646,8 +690,8 @@
            NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true
            NODE_OPTIONS=--max-old-space-size=8192
          cache-from: |
            type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-amd64
            type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
          cache-to: |
            type=inline
            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-amd64,mode=max

@@ -677,7 +721,7 @@
          persist-credentials: false

      - name: Configure AWS credentials
        uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
        with:
          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
          aws-region: us-east-2

@@ -728,8 +772,8 @@
            NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true
            NODE_OPTIONS=--max-old-space-size=8192
          cache-from: |
            type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-arm64
            type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
          cache-to: |
            type=inline
            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-arm64,mode=max

@@ -754,7 +798,7 @@
      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

      - name: Configure AWS credentials
        uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
        with:
          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
          aws-region: us-east-2

@@ -821,7 +865,7 @@
          persist-credentials: false

      - name: Configure AWS credentials
        uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
        with:
          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
          aws-region: us-east-2

@@ -862,8 +906,9 @@
          build-args: |
            ONYX_VERSION=${{ github.ref_name }}
          cache-from: |
            type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-amd64
            type=registry,ref=${{ env.REGISTRY_IMAGE }}:edge
            type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
          cache-to: |
            type=inline
            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-amd64,mode=max

@@ -893,7 +938,7 @@
          persist-credentials: false

      - name: Configure AWS credentials
        uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
        with:
          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
          aws-region: us-east-2

@@ -934,8 +979,9 @@
          build-args: |
            ONYX_VERSION=${{ github.ref_name }}
          cache-from: |
            type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-arm64
            type=registry,ref=${{ env.REGISTRY_IMAGE }}:edge
            type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
          cache-to: |
            type=inline
            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-arm64,mode=max

@@ -960,7 +1006,7 @@
      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

      - name: Configure AWS credentials
        uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
        with:
          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
          aws-region: us-east-2

@@ -1030,7 +1076,7 @@
          persist-credentials: false

      - name: Configure AWS credentials
        uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
        with:
          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
          aws-region: us-east-2

@@ -1072,8 +1118,8 @@
            ONYX_VERSION=${{ github.ref_name }}
            ENABLE_CRAFT=true
          cache-from: |
            type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-craft-cache-amd64
            type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
          cache-to: |
            type=inline
            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-craft-cache-amd64,mode=max

@@ -1103,7 +1149,7 @@
          persist-credentials: false

      - name: Configure AWS credentials
        uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
        with:
          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
          aws-region: us-east-2

@@ -1145,8 +1191,8 @@
            ONYX_VERSION=${{ github.ref_name }}
            ENABLE_CRAFT=true
          cache-from: |
            type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-craft-cache-arm64
            type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
          cache-to: |
            type=inline
            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-craft-cache-arm64,mode=max

@@ -1172,7 +1218,7 @@
      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

      - name: Configure AWS credentials
        uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
        with:
          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
          aws-region: us-east-2

@@ -1242,7 +1288,7 @@
          persist-credentials: false

      - name: Configure AWS credentials
        uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
        with:
          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
          aws-region: us-east-2

@@ -1287,8 +1333,9 @@
          build-args: |
            ONYX_VERSION=${{ github.ref_name }}
          cache-from: |
            type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-amd64
            type=registry,ref=${{ env.REGISTRY_IMAGE }}:edge
            type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
          cache-to: |
            type=inline
            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-amd64,mode=max

@@ -1321,7 +1368,7 @@
          persist-credentials: false

      - name: Configure AWS credentials
        uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
        with:
          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
          aws-region: us-east-2

@@ -1366,8 +1413,9 @@
          build-args: |
            ONYX_VERSION=${{ github.ref_name }}
          cache-from: |
            type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-arm64
            type=registry,ref=${{ env.REGISTRY_IMAGE }}:edge
            type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
          cache-to: |
            type=inline
            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-arm64,mode=max

@@ -1394,7 +1442,7 @@
      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

      - name: Configure AWS credentials
        uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
        with:
          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
          aws-region: us-east-2

@@ -1459,7 +1507,7 @@
      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

      - name: Configure AWS credentials
        uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
        with:
          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
          aws-region: us-east-2

@@ -1514,7 +1562,7 @@
      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

      - name: Configure AWS credentials
        uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
        with:
          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
          aws-region: us-east-2

@@ -1574,7 +1622,7 @@
          persist-credentials: false

      - name: Configure AWS credentials
        uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
        with:
          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
          aws-region: us-east-2

@@ -1631,7 +1679,7 @@
      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

      - name: Configure AWS credentials
        uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
        with:
          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
          aws-region: us-east-2
.github/workflows/helm-chart-releases.yml (vendored, 2 changes)
@@ -33,7 +33,7 @@
          helm repo add cloudnative-pg https://cloudnative-pg.github.io/charts
          helm repo add ot-container-kit https://ot-container-kit.github.io/helm-charts
          helm repo add minio https://charts.min.io/
          helm repo add code-interpreter https://onyx-dot-app.github.io/code-interpreter/
          helm repo add code-interpreter https://onyx-dot-app.github.io/python-sandbox/
          helm repo update

      - name: Build chart dependencies
.github/workflows/nightly-llm-provider-chat.yml (vendored, new file, 51 lines)
@@ -0,0 +1,51 @@
name: Nightly LLM Provider Chat Tests
concurrency:
  group: Nightly-LLM-Provider-Chat-${{ github.workflow }}-${{ github.ref_name }}
  cancel-in-progress: true

on:
  schedule:
    # Runs daily at 10:30 UTC (2:30 AM PST / 3:30 AM PDT)
    - cron: "30 10 * * *"
  workflow_dispatch:

permissions:
  contents: read

jobs:
  provider-chat-test:
    uses: ./.github/workflows/reusable-nightly-llm-provider-chat.yml
    secrets:
      AWS_OIDC_ROLE_ARN: ${{ secrets.AWS_OIDC_ROLE_ARN }}
    permissions:
      contents: read
      id-token: write
    with:
      openai_models: ${{ vars.NIGHTLY_LLM_OPENAI_MODELS }}
      anthropic_models: ${{ vars.NIGHTLY_LLM_ANTHROPIC_MODELS }}
      bedrock_models: ${{ vars.NIGHTLY_LLM_BEDROCK_MODELS }}
      vertex_ai_models: ${{ vars.NIGHTLY_LLM_VERTEX_AI_MODELS }}
      azure_models: ${{ vars.NIGHTLY_LLM_AZURE_MODELS }}
      azure_api_base: ${{ vars.NIGHTLY_LLM_AZURE_API_BASE }}
      ollama_models: ${{ vars.NIGHTLY_LLM_OLLAMA_MODELS }}
      openrouter_models: ${{ vars.NIGHTLY_LLM_OPENROUTER_MODELS }}
      strict: true

  notify-slack-on-failure:
    needs: [provider-chat-test]
    if: failure() && github.event_name == 'schedule'
    runs-on: ubuntu-slim
    timeout-minutes: 5
    steps:
      - name: Checkout
        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
        with:
          persist-credentials: false

      - name: Send Slack notification
        uses: ./.github/actions/slack-notify
        with:
          webhook-url: ${{ secrets.SLACK_WEBHOOK }}
          failed-jobs: provider-chat-test
          title: "🚨 Scheduled LLM Provider Chat Tests failed!"
          ref-name: ${{ github.ref_name }}
165
.github/workflows/post-merge-beta-cherry-pick.yml
vendored
Normal file
165
.github/workflows/post-merge-beta-cherry-pick.yml
vendored
Normal file
@@ -0,0 +1,165 @@
|
||||
name: Post-Merge Beta Cherry-Pick
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
cherry-pick-to-latest-release:
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
outputs:
|
||||
should_cherrypick: ${{ steps.gate.outputs.should_cherrypick }}
|
||||
pr_number: ${{ steps.gate.outputs.pr_number }}
|
||||
cherry_pick_reason: ${{ steps.run_cherry_pick.outputs.reason }}
|
||||
cherry_pick_details: ${{ steps.run_cherry_pick.outputs.details }}
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- name: Resolve merged PR and checkbox state
|
||||
id: gate
|
||||
env:
|
||||
GH_TOKEN: ${{ github.token }}
|
||||
run: |
|
||||
# For the commit that triggered this workflow (HEAD on main), fetch all
|
||||
# associated PRs and keep only the PR that was actually merged into main
|
||||
# with this exact merge commit SHA.
|
||||
pr_numbers="$(gh api "repos/${GITHUB_REPOSITORY}/commits/${GITHUB_SHA}/pulls" | jq -r --arg sha "${GITHUB_SHA}" '.[] | select(.merged_at != null and .base.ref == "main" and .merge_commit_sha == $sha) | .number')"
|
||||
match_count="$(printf '%s\n' "$pr_numbers" | sed '/^[[:space:]]*$/d' | wc -l | tr -d ' ')"
|
||||
pr_number="$(printf '%s\n' "$pr_numbers" | sed '/^[[:space:]]*$/d' | head -n 1)"
|
||||
|
||||
if [ "${match_count}" -gt 1 ]; then
|
||||
echo "::warning::Multiple merged PRs matched commit ${GITHUB_SHA}. Using PR #${pr_number}."
|
||||
fi
|
||||
|
||||
if [ -z "$pr_number" ]; then
|
||||
echo "No merged PR associated with commit ${GITHUB_SHA}; skipping."
|
||||
echo "should_cherrypick=false" >> "$GITHUB_OUTPUT"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Read the PR once so we can gate behavior and infer preferred actor.
|
||||
pr_json="$(gh api "repos/${GITHUB_REPOSITORY}/pulls/${pr_number}")"
|
||||
pr_body="$(printf '%s' "$pr_json" | jq -r '.body // ""')"
|
||||
merged_by="$(printf '%s' "$pr_json" | jq -r '.merged_by.login // ""')"
|
||||
|
||||
echo "pr_number=$pr_number" >> "$GITHUB_OUTPUT"
|
||||
echo "merged_by=$merged_by" >> "$GITHUB_OUTPUT"
|
||||
|
||||
if echo "$pr_body" | grep -qiE "\\[x\\][[:space:]]*(\\[[^]]+\\][[:space:]]*)?Please cherry-pick this PR to the latest release version"; then
|
||||
echo "should_cherrypick=true" >> "$GITHUB_OUTPUT"
|
||||
echo "Cherry-pick checkbox checked for PR #${pr_number}."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "should_cherrypick=false" >> "$GITHUB_OUTPUT"
|
||||
echo "Cherry-pick checkbox not checked for PR #${pr_number}. Skipping."
|
||||
|
||||
- name: Checkout repository
|
||||
        if: steps.gate.outputs.should_cherrypick == 'true'
        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
        with:
          fetch-depth: 0
          persist-credentials: true
          ref: main

      - name: Install the latest version of uv
        if: steps.gate.outputs.should_cherrypick == 'true'
        uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7
        with:
          enable-cache: false
          version: "0.9.9"

      - name: Configure git identity
        if: steps.gate.outputs.should_cherrypick == 'true'
        run: |
          git config user.name "github-actions[bot]"
          git config user.email "github-actions[bot]@users.noreply.github.com"

      - name: Create cherry-pick PR to latest release
        id: run_cherry_pick
        if: steps.gate.outputs.should_cherrypick == 'true'
        continue-on-error: true
        env:
          GH_TOKEN: ${{ github.token }}
          GITHUB_TOKEN: ${{ github.token }}
          CHERRY_PICK_ASSIGNEE: ${{ steps.gate.outputs.merged_by }}
        run: |
          set -o pipefail
          output_file="$(mktemp)"
          uv run --no-sync --with onyx-devtools ods cherry-pick "${GITHUB_SHA}" --yes --no-verify 2>&1 | tee "$output_file"
          exit_code="${PIPESTATUS[0]}"

          if [ "${exit_code}" -eq 0 ]; then
            echo "status=success" >> "$GITHUB_OUTPUT"
            exit 0
          fi

          echo "status=failure" >> "$GITHUB_OUTPUT"

          reason="command-failed"
          if grep -qiE "merge conflict during cherry-pick|CONFLICT|could not apply|cherry-pick in progress with staged changes" "$output_file"; then
            reason="merge-conflict"
          fi
          echo "reason=${reason}" >> "$GITHUB_OUTPUT"

          {
            echo "details<<EOF"
            tail -n 40 "$output_file"
            echo "EOF"
          } >> "$GITHUB_OUTPUT"

      - name: Mark workflow as failed if cherry-pick failed
        if: steps.gate.outputs.should_cherrypick == 'true' && steps.run_cherry_pick.outputs.status == 'failure'
        env:
          CHERRY_PICK_REASON: ${{ steps.run_cherry_pick.outputs.reason }}
        run: |
          echo "::error::Automated cherry-pick failed (${CHERRY_PICK_REASON})."
          exit 1

  notify-slack-on-cherry-pick-failure:
    needs:
      - cherry-pick-to-latest-release
    if: always() && needs.cherry-pick-to-latest-release.outputs.should_cherrypick == 'true' && needs.cherry-pick-to-latest-release.result != 'success'
    runs-on: ubuntu-slim
    timeout-minutes: 10
    steps:
      - name: Checkout
        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
        with:
          persist-credentials: false

      - name: Build cherry-pick failure summary
        id: failure-summary
        env:
          SOURCE_PR_NUMBER: ${{ needs.cherry-pick-to-latest-release.outputs.pr_number }}
          CHERRY_PICK_REASON: ${{ needs.cherry-pick-to-latest-release.outputs.cherry_pick_reason }}
          CHERRY_PICK_DETAILS: ${{ needs.cherry-pick-to-latest-release.outputs.cherry_pick_details }}
        run: |
          source_pr_url="https://github.com/${GITHUB_REPOSITORY}/pull/${SOURCE_PR_NUMBER}"

          reason_text="cherry-pick command failed"
          if [ "${CHERRY_PICK_REASON}" = "merge-conflict" ]; then
            reason_text="merge conflict during cherry-pick"
          fi

          details_excerpt="$(printf '%s' "${CHERRY_PICK_DETAILS}" | tail -n 8 | tr '\n' ' ' | sed "s/[[:space:]]\\+/ /g" | sed "s/\"/'/g" | cut -c1-350)"
          failed_jobs="• cherry-pick-to-latest-release\\n• source PR: ${source_pr_url}\\n• reason: ${reason_text}"
          if [ -n "${details_excerpt}" ]; then
            failed_jobs="${failed_jobs}\\n• excerpt: ${details_excerpt}"
          fi

          echo "jobs=${failed_jobs}" >> "$GITHUB_OUTPUT"

      - name: Notify #cherry-pick-prs about cherry-pick failure
        uses: ./.github/actions/slack-notify
        with:
          webhook-url: ${{ secrets.CHERRY_PICK_PRS_WEBHOOK }}
          failed-jobs: ${{ steps.failure-summary.outputs.jobs }}
          title: "🚨 Automated Cherry-Pick Failed"
          ref-name: ${{ github.ref_name }}
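Two shell patterns in the cherry-pick step above are easy to miss: `PIPESTATUS[0]` preserves the `ods cherry-pick` exit code even though its output is piped through `tee`, and the `key<<EOF` heredoc form is GitHub Actions' documented syntax for multiline step outputs. A minimal standalone sketch of the same pattern (here `./build.sh` and the `log` output name are placeholders, not part of this workflow):

```bash
#!/usr/bin/env bash
set -o pipefail

# "$?" after a pipeline reports tee's status, so capture the first
# pipeline element's status from PIPESTATUS instead.
./build.sh 2>&1 | tee build.log
exit_code="${PIPESTATUS[0]}"

# Single-line outputs are key=value; multiline values use a heredoc
# delimiter that GitHub Actions reads up to the matching EOF line.
{
  if [ "${exit_code}" -eq 0 ]; then
    echo "status=success"
  else
    echo "status=failure"
  fi
  echo "log<<EOF"
  tail -n 40 build.log
  echo "EOF"
} >> "$GITHUB_OUTPUT"
```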
28  .github/workflows/pr-beta-cherrypick-check.yml  vendored
@@ -1,28 +0,0 @@
-name: Require beta cherry-pick consideration
-concurrency:
-  group: Require-Beta-Cherrypick-Consideration-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
-  cancel-in-progress: true
-
-on:
-  pull_request:
-    types: [opened, edited, reopened, synchronize]
-
-permissions:
-  contents: read
-
-jobs:
-  beta-cherrypick-check:
-    runs-on: ubuntu-latest
-    timeout-minutes: 45
-    steps:
-      - name: Check PR body for beta cherry-pick consideration
-        env:
-          PR_BODY: ${{ github.event.pull_request.body }}
-        run: |
-          if echo "$PR_BODY" | grep -qiE "\\[x\\][[:space:]]*\\[Required\\][[:space:]]*I have considered whether this PR needs to be cherry[- ]picked to the latest beta branch"; then
-            echo "Cherry-pick consideration box is checked. Check passed."
-            exit 0
-          fi
-
-          echo "::error::Please check the 'I have considered whether this PR needs to be cherry-picked to the latest beta branch' box in the PR description."
-          exit 1
2  .github/workflows/pr-desktop-build.yml  vendored
@@ -57,7 +57,7 @@ jobs:
           cache-dependency-path: ./desktop/package-lock.json

       - name: Setup Rust
-        uses: dtolnay/rust-toolchain@4be9e76fd7c4901c61fb841f559994984270fce7
+        uses: dtolnay/rust-toolchain@efa25f7f19611383d5b0ccf2d1c8914531636bf9
         with:
           toolchain: stable
           targets: ${{ matrix.target }}
@@ -45,9 +45,6 @@ env:
-  # TODO: debug why this is failing and enable
-  CODE_INTERPRETER_BASE_URL: http://localhost:8000
-
   # OpenSearch
   OPENSEARCH_ADMIN_PASSWORD: "StrongPassword123!"

 jobs:
   discover-test-dirs:
     # NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
@@ -118,9 +115,9 @@ jobs:
       - name: Create .env file for Docker Compose
         run: |
           cat <<EOF > deployment/docker_compose/.env
-          COMPOSE_PROFILES=s3-filestore
-          CODE_INTERPRETER_BETA_ENABLED=true
+          COMPOSE_PROFILES=s3-filestore,opensearch-enabled
           DISABLE_TELEMETRY=true
+          OPENSEARCH_FOR_ONYX_ENABLED=true
           EOF

       - name: Set up Standard Dependencies
@@ -129,7 +126,6 @@ jobs:
           docker compose \
             -f docker-compose.yml \
             -f docker-compose.dev.yml \
-            -f docker-compose.opensearch.yml \
             up -d \
             minio \
             relational_db \
@@ -164,7 +160,7 @@ jobs:
           cd deployment/docker_compose

           # Get list of running containers
-          containers=$(docker compose -f docker-compose.yml -f docker-compose.dev.yml -f docker-compose.opensearch.yml ps -q)
+          containers=$(docker compose -f docker-compose.yml -f docker-compose.dev.yml ps -q)

           # Collect logs from each container
           for container in $containers; do
56  .github/workflows/pr-golang-tests.yml  vendored  Normal file
@@ -0,0 +1,56 @@
+name: Golang Tests
+concurrency:
+  group: Golang-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
+  cancel-in-progress: true
+
+on:
+  merge_group:
+  pull_request:
+    branches:
+      - main
+      - "release/**"
+  push:
+    tags:
+      - "v*.*.*"
+
+permissions: {}
+
+env:
+  GO_VERSION: "1.26"
+
+jobs:
+  detect-modules:
+    runs-on: ubuntu-latest
+    timeout-minutes: 10
+    outputs:
+      modules: ${{ steps.set-modules.outputs.modules }}
+    steps:
+      - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8
+        with:
+          persist-credentials: false
+      - id: set-modules
+        run: echo "modules=$(find . -name 'go.mod' -exec dirname {} \; | jq -Rc '[.,inputs]')" >> "$GITHUB_OUTPUT"
+
+  golang:
+    needs: detect-modules
+    runs-on: ubuntu-latest
+    timeout-minutes: 10
+    strategy:
+      matrix:
+        modules: ${{ fromJSON(needs.detect-modules.outputs.modules) }}
+    steps:
+      - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # ratchet:actions/checkout@v6
+        with:
+          persist-credentials: false
+      - uses: actions/setup-go@4dc6199c7b1a012772edbd06daecab0f50c9053c # zizmor: ignore[cache-poisoning]
+        with:
+          go-version: ${{ env.GO_VERSION }}
+          cache-dependency-path: "**/go.sum"
+
+      - run: go mod tidy
+        working-directory: ${{ matrix.modules }}
+      - run: git diff --exit-code go.mod go.sum
+        working-directory: ${{ matrix.modules }}
+
+      - run: go test ./...
+        working-directory: ${{ matrix.modules }}
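The `detect-modules` job above builds its job matrix from the filesystem: `find` prints the parent directory of every `go.mod`, and `jq -Rc '[.,inputs]'` folds those raw lines into one compact JSON array (`.` is the first input line, `inputs` streams the rest), which `fromJSON()` then expands in the `golang` job. Run locally it behaves roughly like this (the output shown is for a hypothetical repo layout):

```bash
# One entry per Go module; the JSON array feeds fromJSON() in the matrix.
find . -name 'go.mod' -exec dirname {} \; | jq -Rc '[., inputs]'
# => ["./tools/devtools","./cli"]   (example output; actual paths vary)
```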
4  .github/workflows/pr-helm-chart-testing.yml  vendored
@@ -71,7 +71,7 @@ jobs:

       - name: Create kind cluster
         if: steps.list-changed.outputs.changed == 'true'
-        uses: helm/kind-action@92086f6be054225fa813e0a4b13787fc9088faab # ratchet:helm/kind-action@v1.13.0
+        uses: helm/kind-action@ef37e7f390d99f746eb8b610417061a60e82a6cc # ratchet:helm/kind-action@v1.14.0

       - name: Pre-install cluster status check
         if: steps.list-changed.outputs.changed == 'true'
@@ -91,7 +91,7 @@ jobs:
           helm repo add cloudnative-pg https://cloudnative-pg.github.io/charts
           helm repo add ot-container-kit https://ot-container-kit.github.io/helm-charts
           helm repo add minio https://charts.min.io/
-          helm repo add code-interpreter https://onyx-dot-app.github.io/code-interpreter/
+          helm repo add code-interpreter https://onyx-dot-app.github.io/python-sandbox/
           helm repo update

       - name: Install Redis operator
60  .github/workflows/pr-integration-tests.yml  vendored
@@ -20,6 +20,7 @@ env:
   # Test Environment Variables
   OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
   SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
+  SLACK_BOT_TOKEN_TEST_SPACE: ${{ secrets.SLACK_BOT_TOKEN_TEST_SPACE }}
   CONFLUENCE_TEST_SPACE_URL: ${{ vars.CONFLUENCE_TEST_SPACE_URL }}
   CONFLUENCE_USER_NAME: ${{ vars.CONFLUENCE_USER_NAME }}
   CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
@@ -315,6 +316,7 @@ jobs:
           # Base config shared by both editions
           cat <<EOF > deployment/docker_compose/.env
           COMPOSE_PROFILES=s3-filestore
+          OPENSEARCH_FOR_ONYX_ENABLED=false
           AUTH_TYPE=basic
           POSTGRES_POOL_PRE_PING=true
           POSTGRES_USE_NULL_POOL=true
@@ -334,7 +336,6 @@ jobs:
           # TODO(Nik): https://linear.app/onyx-app/issue/ENG-1/update-test-infra-to-use-test-license
           LICENSE_ENFORCEMENT_ENABLED=false
           CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS=0.001
-          USE_LIGHTWEIGHT_BACKGROUND_WORKER=false
           EOF
           fi

@@ -418,11 +419,13 @@ jobs:
             -e POSTGRES_POOL_PRE_PING=true \
             -e POSTGRES_USE_NULL_POOL=true \
             -e VESPA_HOST=index \
+            -e ENABLE_OPENSEARCH_INDEXING_FOR_ONYX=false \
             -e REDIS_HOST=cache \
             -e API_SERVER_HOST=api_server \
             -e OPENAI_API_KEY=${OPENAI_API_KEY} \
             -e EXA_API_KEY=${EXA_API_KEY} \
             -e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
+            -e SLACK_BOT_TOKEN_TEST_SPACE=${SLACK_BOT_TOKEN_TEST_SPACE} \
             -e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
             -e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
             -e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
@@ -443,6 +446,7 @@ jobs:
             -e TEST_WEB_HOSTNAME=test-runner \
             -e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
             -e MOCK_CONNECTOR_SERVER_PORT=8001 \
+            -e ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=${{ matrix.edition == 'ee' && 'true' || 'false' }} \
             ${{ env.RUNS_ON_ECR_CACHE }}:integration-test-${{ github.run_id }} \
             /app/tests/integration/${{ matrix.test-dir.path }}

@@ -468,13 +472,13 @@ jobs:
           path: ${{ github.workspace }}/docker-compose.log
   # ------------------------------------------------------------

-  no-vectordb-tests:
+  onyx-lite-tests:
     needs: [build-backend-image, build-integration-image]
     runs-on:
       [
         runs-on,
         runner=4cpu-linux-arm64,
-        "run-id=${{ github.run_id }}-no-vectordb-tests",
+        "run-id=${{ github.run_id }}-onyx-lite-tests",
         "extras=ecr-cache",
       ]
     timeout-minutes: 45
@@ -492,13 +496,12 @@ jobs:
           username: ${{ secrets.DOCKER_USERNAME }}
           password: ${{ secrets.DOCKER_TOKEN }}

-      - name: Create .env file for no-vectordb Docker Compose
+      - name: Create .env file for Onyx Lite Docker Compose
         env:
           ECR_CACHE: ${{ env.RUNS_ON_ECR_CACHE }}
           RUN_ID: ${{ github.run_id }}
         run: |
           cat <<EOF > deployment/docker_compose/.env
-          COMPOSE_PROFILES=s3-filestore
           ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
           LICENSE_ENFORCEMENT_ENABLED=false
           AUTH_TYPE=basic
@@ -506,28 +509,23 @@ jobs:
           POSTGRES_USE_NULL_POOL=true
           REQUIRE_EMAIL_VERIFICATION=false
           DISABLE_TELEMETRY=true
-          DISABLE_VECTOR_DB=true
           ONYX_BACKEND_IMAGE=${ECR_CACHE}:integration-test-backend-test-${RUN_ID}
           INTEGRATION_TESTS_MODE=true
-          USE_LIGHTWEIGHT_BACKGROUND_WORKER=true
           EOF

-      # Start only the services needed for no-vectordb mode (no Vespa, no model servers)
-      - name: Start Docker containers (no-vectordb)
+      # Start only the services needed for Onyx Lite (Postgres + API server)
+      - name: Start Docker containers (onyx-lite)
         run: |
           cd deployment/docker_compose
-          docker compose -f docker-compose.yml -f docker-compose.no-vectordb.yml -f docker-compose.dev.yml up \
+          docker compose -f docker-compose.yml -f docker-compose.onyx-lite.yml -f docker-compose.dev.yml up \
             relational_db \
             cache \
             minio \
             api_server \
             background \
             -d
-        id: start_docker_no_vectordb
+        id: start_docker_onyx_lite

       - name: Wait for services to be ready
         run: |
-          echo "Starting wait-for-service script (no-vectordb)..."
+          echo "Starting wait-for-service script (onyx-lite)..."
           start_time=$(date +%s)
           timeout=300
           while true; do
@@ -549,14 +547,14 @@ jobs:
             sleep 5
           done

-      - name: Run No-VectorDB Integration Tests
+      - name: Run Onyx Lite Integration Tests
         uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3
         with:
           timeout_minutes: 20
           max_attempts: 3
           retry_wait_seconds: 10
           command: |
-            echo "Running no-vectordb integration tests..."
+            echo "Running onyx-lite integration tests..."
             docker run --rm --network onyx_default \
               --name test-runner \
               -e POSTGRES_HOST=relational_db \
@@ -567,39 +565,38 @@ jobs:
               -e DB_READONLY_PASSWORD=password \
               -e POSTGRES_POOL_PRE_PING=true \
               -e POSTGRES_USE_NULL_POOL=true \
               -e REDIS_HOST=cache \
               -e API_SERVER_HOST=api_server \
               -e OPENAI_API_KEY=${OPENAI_API_KEY} \
               -e TEST_WEB_HOSTNAME=test-runner \
               ${{ env.RUNS_ON_ECR_CACHE }}:integration-test-${{ github.run_id }} \
               /app/tests/integration/tests/no_vectordb

-      - name: Dump API server logs (no-vectordb)
+      - name: Dump API server logs (onyx-lite)
         if: always()
         run: |
           cd deployment/docker_compose
-          docker compose -f docker-compose.yml -f docker-compose.no-vectordb.yml -f docker-compose.dev.yml \
-            logs --no-color api_server > $GITHUB_WORKSPACE/api_server_no_vectordb.log || true
+          docker compose -f docker-compose.yml -f docker-compose.onyx-lite.yml -f docker-compose.dev.yml \
+            logs --no-color api_server > $GITHUB_WORKSPACE/api_server_onyx_lite.log || true

-      - name: Dump all-container logs (no-vectordb)
+      - name: Dump all-container logs (onyx-lite)
         if: always()
         run: |
           cd deployment/docker_compose
-          docker compose -f docker-compose.yml -f docker-compose.no-vectordb.yml -f docker-compose.dev.yml \
-            logs --no-color > $GITHUB_WORKSPACE/docker-compose-no-vectordb.log || true
+          docker compose -f docker-compose.yml -f docker-compose.onyx-lite.yml -f docker-compose.dev.yml \
+            logs --no-color > $GITHUB_WORKSPACE/docker-compose-onyx-lite.log || true

-      - name: Upload logs (no-vectordb)
+      - name: Upload logs (onyx-lite)
         if: always()
         uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
         with:
-          name: docker-all-logs-no-vectordb
-          path: ${{ github.workspace }}/docker-compose-no-vectordb.log
+          name: docker-all-logs-onyx-lite
+          path: ${{ github.workspace }}/docker-compose-onyx-lite.log

-      - name: Stop Docker containers (no-vectordb)
+      - name: Stop Docker containers (onyx-lite)
         if: always()
         run: |
           cd deployment/docker_compose
-          docker compose -f docker-compose.yml -f docker-compose.no-vectordb.yml -f docker-compose.dev.yml down -v
+          docker compose -f docker-compose.yml -f docker-compose.onyx-lite.yml -f docker-compose.dev.yml down -v

   multitenant-tests:
     needs:
@@ -642,6 +639,7 @@ jobs:
           ONYX_BACKEND_IMAGE=${ECR_CACHE}:integration-test-backend-test-${RUN_ID} \
           ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID} \
           DEV_MODE=true \
+          OPENSEARCH_FOR_ONYX_ENABLED=false \
           docker compose -f docker-compose.multitenant-dev.yml up \
             relational_db \
             index \
@@ -696,11 +694,13 @@ jobs:
             -e POSTGRES_DB=postgres \
             -e POSTGRES_USE_NULL_POOL=true \
             -e VESPA_HOST=index \
+            -e ENABLE_OPENSEARCH_INDEXING_FOR_ONYX=false \
             -e REDIS_HOST=cache \
             -e API_SERVER_HOST=api_server \
             -e OPENAI_API_KEY=${OPENAI_API_KEY} \
             -e EXA_API_KEY=${EXA_API_KEY} \
             -e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
+            -e SLACK_BOT_TOKEN_TEST_SPACE=${SLACK_BOT_TOKEN_TEST_SPACE} \
             -e TEST_WEB_HOSTNAME=test-runner \
             -e AUTH_TYPE=cloud \
             -e MULTI_TENANT=true \
@@ -740,7 +740,7 @@ jobs:
     # NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
     runs-on: ubuntu-slim
     timeout-minutes: 45
-    needs: [integration-tests, no-vectordb-tests, multitenant-tests]
+    needs: [integration-tests, onyx-lite-tests, multitenant-tests]
     if: ${{ always() }}
     steps:
       - name: Check job status
2  .github/workflows/pr-jest-tests.yml  vendored
@@ -31,7 +31,7 @@ jobs:
         uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v4
         with:
           node-version: 22
-          cache: "npm"
+          cache: "npm" # zizmor: ignore[cache-poisoning] test-only workflow; no deploy artifacts
           cache-dependency-path: ./web/package-lock.json

       - name: Install node dependencies
117  .github/workflows/pr-playwright-tests.yml  vendored
@@ -12,6 +12,9 @@ on:
   push:
     tags:
       - "v*.*.*"
+    # TODO: Remove this if we enable merge-queues for release branches.
+    branches:
+      - "release/**"

 permissions:
   contents: read
@@ -268,10 +271,11 @@ jobs:
           persist-credentials: false

       - name: Setup node
+        # zizmor: ignore[cache-poisoning] ephemeral runners; no release artifacts
         uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v4
         with:
           node-version: 22
-          cache: "npm"
+          cache: "npm" # zizmor: ignore[cache-poisoning]
           cache-dependency-path: ./web/package-lock.json

       - name: Install node dependencies
@@ -279,6 +283,7 @@ jobs:
         run: npm ci

       - name: Cache playwright cache
+        # zizmor: ignore[cache-poisoning] ephemeral runners; no release artifacts
         uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
         with:
           path: ~/.cache/ms-playwright
@@ -459,14 +464,14 @@ jobs:
       # --- Visual Regression Diff ---
       - name: Configure AWS credentials
         if: always()
-        uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
+        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
         with:
           role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
           aws-region: us-east-2

       - name: Install the latest version of uv
         if: always()
-        uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
+        uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7
         with:
           enable-cache: false
           version: "0.9.9"
@@ -590,6 +595,108 @@ jobs:
           name: docker-logs-${{ matrix.project }}-${{ github.run_id }}
           path: ${{ github.workspace }}/docker-compose.log

+  playwright-tests-lite:
+    needs: [build-web-image, build-backend-image]
+    name: Playwright Tests (lite)
+    runs-on:
+      - runs-on
+      - runner=4cpu-linux-arm64
+      - "run-id=${{ github.run_id }}-playwright-tests-lite"
+      - "extras=ecr-cache"
+    timeout-minutes: 30
+    steps:
+      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
+
+      - name: Checkout code
+        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
+        with:
+          persist-credentials: false
+
+      - name: Setup node
+        # zizmor: ignore[cache-poisoning] ephemeral runners; no release artifacts
+        uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v4
+        with:
+          node-version: 22
+          cache: "npm" # zizmor: ignore[cache-poisoning]
+          cache-dependency-path: ./web/package-lock.json
+
+      - name: Install node dependencies
+        working-directory: ./web
+        run: npm ci
+
+      - name: Cache playwright cache
+        # zizmor: ignore[cache-poisoning] ephemeral runners; no release artifacts
+        uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
+        with:
+          path: ~/.cache/ms-playwright
+          key: ${{ runner.os }}-playwright-npm-${{ hashFiles('web/package-lock.json') }}
+          restore-keys: |
+            ${{ runner.os }}-playwright-npm-
+
+      - name: Install playwright browsers
+        working-directory: ./web
+        run: npx playwright install --with-deps
+
+      - name: Create .env file for Docker Compose
+        env:
+          OPENAI_API_KEY_VALUE: ${{ env.OPENAI_API_KEY }}
+          ECR_CACHE: ${{ env.RUNS_ON_ECR_CACHE }}
+          RUN_ID: ${{ github.run_id }}
+        run: |
+          cat <<EOF > deployment/docker_compose/.env
+          ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
+          LICENSE_ENFORCEMENT_ENABLED=false
+          AUTH_TYPE=basic
+          INTEGRATION_TESTS_MODE=true
+          GEN_AI_API_KEY=${OPENAI_API_KEY_VALUE}
+          MOCK_LLM_RESPONSE=true
+          REQUIRE_EMAIL_VERIFICATION=false
+          DISABLE_TELEMETRY=true
+          ONYX_BACKEND_IMAGE=${ECR_CACHE}:playwright-test-backend-${RUN_ID}
+          ONYX_WEB_SERVER_IMAGE=${ECR_CACHE}:playwright-test-web-${RUN_ID}
+          EOF
+
+      # needed for pulling external images otherwise, we hit the "Unauthenticated users" limit
+      # https://docs.docker.com/docker-hub/usage/
+      - name: Login to Docker Hub
+        uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
+        with:
+          username: ${{ secrets.DOCKER_USERNAME }}
+          password: ${{ secrets.DOCKER_TOKEN }}
+
+      - name: Start Docker containers (lite)
+        run: |
+          cd deployment/docker_compose
+          docker compose -f docker-compose.yml -f docker-compose.onyx-lite.yml -f docker-compose.dev.yml up -d
+        id: start_docker
+
+      - name: Run Playwright tests (lite)
+        working-directory: ./web
+        run: npx playwright test --project lite
+
+      - uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
+        if: always()
+        with:
+          name: playwright-test-results-lite-${{ github.run_id }}
+          path: ./web/output/playwright/
+          retention-days: 30
+
+      - name: Save Docker logs
+        if: success() || failure()
+        env:
+          WORKSPACE: ${{ github.workspace }}
+        run: |
+          cd deployment/docker_compose
+          docker compose logs > docker-compose.log
+          mv docker-compose.log ${WORKSPACE}/docker-compose.log
+
+      - name: Upload logs
+        if: success() || failure()
+        uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
+        with:
+          name: docker-logs-lite-${{ github.run_id }}
+          path: ${{ github.workspace }}/docker-compose.log
+
   # Post a single combined visual regression comment after all matrix jobs finish
   visual-regression-comment:
     needs: [playwright-tests]
@@ -603,7 +710,7 @@ jobs:
       pull-requests: write
     steps:
       - name: Download visual diff summaries
-        uses: actions/download-artifact@95815c38cf2ff2164869cbab79da8d1f422bc89e # ratchet:actions/download-artifact@v4
+        uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3
        with:
          pattern: screenshot-diff-summary-*
          path: summaries/
@@ -686,7 +793,7 @@ jobs:
     # NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
     runs-on: ubuntu-slim
     timeout-minutes: 45
-    needs: [playwright-tests]
+    needs: [playwright-tests, playwright-tests-lite]
     if: ${{ always() }}
     steps:
       - name: Check job status
21  .github/workflows/pr-python-checks.yml  vendored
@@ -8,7 +8,7 @@ on:
   pull_request:
     branches:
       - main
-      - 'release/**'
+      - "release/**"
   push:
     tags:
       - "v*.*.*"
@@ -21,7 +21,13 @@ jobs:
     # See https://runs-on.com/runners/linux/
     # Note: Mypy seems quite optimized for x64 compared to arm64.
     # Similarly, mypy is single-threaded and incremental, so 2cpu is sufficient.
-    runs-on: [runs-on, runner=2cpu-linux-x64, "run-id=${{ github.run_id }}-mypy-check", "extras=s3-cache"]
+    runs-on:
+      [
+        runs-on,
+        runner=2cpu-linux-x64,
+        "run-id=${{ github.run_id }}-mypy-check",
+        "extras=s3-cache",
+      ]
     timeout-minutes: 45

     steps:
@@ -52,21 +58,14 @@ jobs:
         if: ${{ vars.DISABLE_MYPY_CACHE != 'true' }}
         uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4
         with:
-          path: backend/.mypy_cache
-          key: mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-${{ hashFiles('**/*.py', '**/*.pyi', 'backend/pyproject.toml') }}
+          path: .mypy_cache
+          key: mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-${{ hashFiles('**/*.py', '**/*.pyi', 'pyproject.toml') }}
           restore-keys: |
             mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-
             mypy-${{ runner.os }}-

       - name: Run MyPy
-        working-directory: ./backend
         env:
           MYPY_FORCE_COLOR: 1
           TERM: xterm-256color
         run: mypy .
-
-      - name: Run MyPy (tools/)
-        env:
-          MYPY_FORCE_COLOR: 1
-          TERM: xterm-256color
-        run: mypy tools/
@@ -89,6 +89,10 @@ env:
   SHAREPOINT_CLIENT_SECRET: ${{ secrets.SHAREPOINT_CLIENT_SECRET }}
   SHAREPOINT_CLIENT_DIRECTORY_ID: ${{ vars.SHAREPOINT_CLIENT_DIRECTORY_ID }}
   SHAREPOINT_SITE: ${{ vars.SHAREPOINT_SITE }}
+  PERM_SYNC_SHAREPOINT_CLIENT_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_CLIENT_ID }}
+  PERM_SYNC_SHAREPOINT_PRIVATE_KEY: ${{ secrets.PERM_SYNC_SHAREPOINT_PRIVATE_KEY }}
+  PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD: ${{ secrets.PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD }}
+  PERM_SYNC_SHAREPOINT_DIRECTORY_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_DIRECTORY_ID }}

   # Github
   ACCESS_TOKEN_GITHUB: ${{ secrets.ACCESS_TOKEN_GITHUB }}
6  .github/workflows/pr-quality-checks.yml  vendored
@@ -28,7 +28,7 @@ jobs:
         with:
           python-version: "3.11"
       - name: Setup Terraform
-        uses: hashicorp/setup-terraform@b9cd54a3c349d3f38e8881555d616ced269862dd # ratchet:hashicorp/setup-terraform@v3
+        uses: hashicorp/setup-terraform@5e8dbf3c6d9deaf4193ca7a8fb23f2ac83bb6c85 # ratchet:hashicorp/setup-terraform@v4.0.0
       - name: Setup node
         uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v6
         with: # zizmor: ignore[cache-poisoning]
@@ -38,9 +38,9 @@ jobs:
       - name: Install node dependencies
         working-directory: ./web
         run: npm ci
-      - uses: j178/prek-action@9d6a3097e0c1865ecce00cfb89fe80f2ee91b547 # ratchet:j178/prek-action@v1
+      - uses: j178/prek-action@0bb87d7f00b0c99306c8bcb8b8beba1eb581c037 # ratchet:j178/prek-action@v1
         with:
-          prek-version: '0.2.21'
+          prek-version: '0.3.4'
           extra-args: ${{ github.event_name == 'pull_request' && format('--from-ref {0} --to-ref {1}', github.event.pull_request.base.sha, github.event.pull_request.head.sha) || github.event_name == 'merge_group' && format('--from-ref {0} --to-ref {1}', github.event.merge_group.base_sha, github.event.merge_group.head_sha) || github.ref_name == 'main' && '--all-files' || '' }}
       - name: Check Actions
         uses: giner/check-actions@28d366c7cbbe235f9624a88aa31a628167eee28c # ratchet:giner/check-actions@v1.0.1
214  .github/workflows/release-cli.yml  vendored  Normal file
@@ -0,0 +1,214 @@
+name: Release CLI
+
+on:
+  push:
+    tags:
+      - "cli/v*.*.*"
+
+jobs:
+  pypi:
+    runs-on: ubuntu-latest
+    environment:
+      name: release-cli
+    permissions:
+      id-token: write
+    timeout-minutes: 10
+    strategy:
+      matrix:
+        os-arch:
+          - { goos: "linux", goarch: "amd64" }
+          - { goos: "linux", goarch: "arm64" }
+          - { goos: "windows", goarch: "amd64" }
+          - { goos: "windows", goarch: "arm64" }
+          - { goos: "darwin", goarch: "amd64" }
+          - { goos: "darwin", goarch: "arm64" }
+    steps:
+      - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
+        with:
+          persist-credentials: false
+      - uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7
+        with:
+          enable-cache: false
+          version: "0.9.9"
+      - run: |
+          GOOS="${{ matrix.os-arch.goos }}" \
+          GOARCH="${{ matrix.os-arch.goarch }}" \
+          uv build --wheel
+        working-directory: cli
+      - run: uv publish
+        working-directory: cli
+
+  docker-amd64:
+    runs-on:
+      - runs-on
+      - runner=2cpu-linux-x64
+      - run-id=${{ github.run_id }}-cli-amd64
+      - extras=ecr-cache
+    environment: deploy
+    permissions:
+      id-token: write
+    timeout-minutes: 30
+    outputs:
+      digest: ${{ steps.build.outputs.digest }}
+    env:
+      REGISTRY_IMAGE: onyxdotapp/onyx-cli
+    steps:
+      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
+
+      - name: Checkout
+        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
+        with:
+          persist-credentials: false
+
+      - name: Configure AWS credentials
+        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 # ratchet:aws-actions/configure-aws-credentials@v6.0.0
+        with:
+          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
+          aws-region: us-east-2
+
+      - name: Get AWS Secrets
+        uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802 # ratchet:aws-actions/aws-secretsmanager-get-secrets@v2.0.10
+        with:
+          secret-ids: |
+            DOCKER_USERNAME, deploy/docker-username
+            DOCKER_TOKEN, deploy/docker-token
+          parse-json-secrets: true
+
+      - name: Set up Docker Buildx
+        uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
+
+      - name: Login to Docker Hub
+        uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # ratchet:docker/login-action@v4
+        with:
+          username: ${{ env.DOCKER_USERNAME }}
+          password: ${{ env.DOCKER_TOKEN }}
+
+      - name: Build and push AMD64
+        id: build
+        uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # ratchet:docker/build-push-action@v7
+        with:
+          context: ./cli
+          file: ./cli/Dockerfile
+          platforms: linux/amd64
+          cache-from: type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
+          cache-to: type=inline
+          outputs: type=image,name=${{ env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
+
+  docker-arm64:
+    runs-on:
+      - runs-on
+      - runner=2cpu-linux-arm64
+      - run-id=${{ github.run_id }}-cli-arm64
+      - extras=ecr-cache
+    environment: deploy
+    permissions:
+      id-token: write
+    timeout-minutes: 30
+    outputs:
+      digest: ${{ steps.build.outputs.digest }}
+    env:
+      REGISTRY_IMAGE: onyxdotapp/onyx-cli
+    steps:
+      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
+
+      - name: Checkout
+        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
+        with:
+          persist-credentials: false
+
+      - name: Configure AWS credentials
+        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 # ratchet:aws-actions/configure-aws-credentials@v6.0.0
+        with:
+          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
+          aws-region: us-east-2
+
+      - name: Get AWS Secrets
+        uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802 # ratchet:aws-actions/aws-secretsmanager-get-secrets@v2.0.10
+        with:
+          secret-ids: |
+            DOCKER_USERNAME, deploy/docker-username
+            DOCKER_TOKEN, deploy/docker-token
+          parse-json-secrets: true
+
+      - name: Set up Docker Buildx
+        uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
+
+      - name: Login to Docker Hub
+        uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # ratchet:docker/login-action@v4
+        with:
+          username: ${{ env.DOCKER_USERNAME }}
+          password: ${{ env.DOCKER_TOKEN }}
+
+      - name: Build and push ARM64
+        id: build
+        uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # ratchet:docker/build-push-action@v7
+        with:
+          context: ./cli
+          file: ./cli/Dockerfile
+          platforms: linux/arm64
+          cache-from: type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
+          cache-to: type=inline
+          outputs: type=image,name=${{ env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
+
+  merge-docker:
+    needs:
+      - docker-amd64
+      - docker-arm64
+    runs-on:
+      - runs-on
+      - runner=2cpu-linux-x64
+      - run-id=${{ github.run_id }}-cli-merge
+    environment: deploy
+    permissions:
+      id-token: write
+    timeout-minutes: 10
+    env:
+      REGISTRY_IMAGE: onyxdotapp/onyx-cli
+    steps:
+      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
+
+      - name: Configure AWS credentials
+        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 # ratchet:aws-actions/configure-aws-credentials@v6.0.0
+        with:
+          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
+          aws-region: us-east-2
+
+      - name: Get AWS Secrets
+        uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802 # ratchet:aws-actions/aws-secretsmanager-get-secrets@v2.0.10
+        with:
+          secret-ids: |
+            DOCKER_USERNAME, deploy/docker-username
+            DOCKER_TOKEN, deploy/docker-token
+          parse-json-secrets: true
+
+      - name: Set up Docker Buildx
+        uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
+
+      - name: Login to Docker Hub
+        uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # ratchet:docker/login-action@v4
+        with:
+          username: ${{ env.DOCKER_USERNAME }}
+          password: ${{ env.DOCKER_TOKEN }}
+
+      - name: Create and push manifest
+        env:
+          AMD64_DIGEST: ${{ needs.docker-amd64.outputs.digest }}
+          ARM64_DIGEST: ${{ needs.docker-arm64.outputs.digest }}
+          TAG: ${{ github.ref_name }}
+        run: |
+          SANITIZED_TAG="${TAG#cli/}"
+          IMAGES=(
+            "${REGISTRY_IMAGE}@${AMD64_DIGEST}"
+            "${REGISTRY_IMAGE}@${ARM64_DIGEST}"
+          )
+
+          if [[ "$TAG" =~ ^cli/v[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
+            docker buildx imagetools create \
+              -t "${REGISTRY_IMAGE}:${SANITIZED_TAG}" \
+              -t "${REGISTRY_IMAGE}:latest" \
+              "${IMAGES[@]}"
+          else
+            docker buildx imagetools create \
+              -t "${REGISTRY_IMAGE}:${SANITIZED_TAG}" \
+              "${IMAGES[@]}"
+          fi
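The release flow above uses a common multi-arch pattern: each architecture job pushes an untagged image by digest (`push-by-digest=true`), and `merge-docker` then combines the two digests into one multi-platform tag. In isolation the merge step looks like the sketch below, where the image name and digests are placeholders rather than real outputs:

```bash
# Stitch two single-arch images (already pushed by digest) into one
# multi-platform tag; each digest is the build-push-action "digest" output.
docker buildx imagetools create \
  -t example/app:v1.2.3 \
  -t example/app:latest \
  "example/app@sha256:<amd64-digest>" \
  "example/app@sha256:<arm64-digest>"
```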
4  .github/workflows/release-devtools.yml  vendored
@@ -22,13 +22,11 @@ jobs:
           - { goos: "windows", goarch: "arm64" }
           - { goos: "darwin", goarch: "amd64" }
           - { goos: "darwin", goarch: "arm64" }
-          - { goos: "", goarch: "" }
     steps:
       - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
         with:
           persist-credentials: false
-          fetch-depth: 0
-      - uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
+      - uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7
         with:
           enable-cache: false
           version: "0.9.9"
333  .github/workflows/reusable-nightly-llm-provider-chat.yml  vendored  Normal file
@@ -0,0 +1,333 @@
+name: Reusable Nightly LLM Provider Chat Tests
+
+on:
+  workflow_call:
+    inputs:
+      openai_models:
+        description: "Comma-separated models for openai"
+        required: false
+        default: ""
+        type: string
+      anthropic_models:
+        description: "Comma-separated models for anthropic"
+        required: false
+        default: ""
+        type: string
+      bedrock_models:
+        description: "Comma-separated models for bedrock"
+        required: false
+        default: ""
+        type: string
+      vertex_ai_models:
+        description: "Comma-separated models for vertex_ai"
+        required: false
+        default: ""
+        type: string
+      azure_models:
+        description: "Comma-separated models for azure"
+        required: false
+        default: ""
+        type: string
+      ollama_models:
+        description: "Comma-separated models for ollama_chat"
+        required: false
+        default: ""
+        type: string
+      openrouter_models:
+        description: "Comma-separated models for openrouter"
+        required: false
+        default: ""
+        type: string
+      azure_api_base:
+        description: "API base for azure provider"
+        required: false
+        default: ""
+        type: string
+      strict:
+        description: "Default NIGHTLY_LLM_STRICT passed to tests"
+        required: false
+        default: true
+        type: boolean
+    secrets:
+      AWS_OIDC_ROLE_ARN:
+        description: "AWS role ARN for OIDC auth"
+        required: true
+
+permissions:
+  contents: read
+  id-token: write
+
+jobs:
+  build-backend-image:
+    runs-on:
+      [
+        runs-on,
+        runner=1cpu-linux-arm64,
+        "run-id=${{ github.run_id }}-build-backend-image",
+        "extras=ecr-cache",
+      ]
+    timeout-minutes: 45
+    environment: ci-protected
+    steps:
+      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
+
+      - name: Checkout code
+        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
+        with:
+          persist-credentials: false
+
+      - name: Configure AWS credentials
+        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
+        with:
+          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
+          aws-region: us-east-2
+
+      - name: Get AWS Secrets
+        uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
+        with:
+          secret-ids: |
+            DOCKER_USERNAME, test/docker-username
+            DOCKER_TOKEN, test/docker-token
+
+      - name: Build backend image
+        uses: ./.github/actions/build-backend-image
+        with:
+          runs-on-ecr-cache: ${{ env.RUNS_ON_ECR_CACHE }}
+          ref-name: ${{ github.ref_name }}
+          pr-number: ${{ github.event.pull_request.number }}
+          github-sha: ${{ github.sha }}
+          run-id: ${{ github.run_id }}
+          docker-username: ${{ env.DOCKER_USERNAME }}
+          docker-token: ${{ env.DOCKER_TOKEN }}
+          docker-no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' && 'true' || 'false' }}
+
+  build-model-server-image:
+    runs-on:
+      [
+        runs-on,
+        runner=1cpu-linux-arm64,
+        "run-id=${{ github.run_id }}-build-model-server-image",
+        "extras=ecr-cache",
+      ]
+    timeout-minutes: 45
+    environment: ci-protected
+    steps:
+      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
+
+      - name: Checkout code
+        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
+        with:
+          persist-credentials: false
+
+      - name: Configure AWS credentials
+        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
+        with:
+          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
+          aws-region: us-east-2
+
+      - name: Get AWS Secrets
+        uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
+        with:
+          secret-ids: |
+            DOCKER_USERNAME, test/docker-username
+            DOCKER_TOKEN, test/docker-token
+
+      - name: Build model server image
+        uses: ./.github/actions/build-model-server-image
+        with:
+          runs-on-ecr-cache: ${{ env.RUNS_ON_ECR_CACHE }}
+          ref-name: ${{ github.ref_name }}
+          pr-number: ${{ github.event.pull_request.number }}
+          github-sha: ${{ github.sha }}
+          run-id: ${{ github.run_id }}
+          docker-username: ${{ env.DOCKER_USERNAME }}
+          docker-token: ${{ env.DOCKER_TOKEN }}
+
+  build-integration-image:
+    runs-on:
+      [
+        runs-on,
+        runner=2cpu-linux-arm64,
+        "run-id=${{ github.run_id }}-build-integration-image",
+        "extras=ecr-cache",
+      ]
+    timeout-minutes: 45
+    environment: ci-protected
+    steps:
+      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
+
+      - name: Checkout code
+        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
+        with:
+          persist-credentials: false
+
+      - name: Configure AWS credentials
+        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
+        with:
+          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
+          aws-region: us-east-2
+
+      - name: Get AWS Secrets
+        uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
+        with:
+          secret-ids: |
+            DOCKER_USERNAME, test/docker-username
+            DOCKER_TOKEN, test/docker-token
+
+      - name: Build integration image
+        uses: ./.github/actions/build-integration-image
+        with:
+          runs-on-ecr-cache: ${{ env.RUNS_ON_ECR_CACHE }}
+          ref-name: ${{ github.ref_name }}
+          pr-number: ${{ github.event.pull_request.number }}
+          github-sha: ${{ github.sha }}
+          run-id: ${{ github.run_id }}
+          docker-username: ${{ env.DOCKER_USERNAME }}
+          docker-token: ${{ env.DOCKER_TOKEN }}
+
+  provider-chat-test:
+    needs:
+      [
+        build-backend-image,
+        build-model-server-image,
+        build-integration-image,
+      ]
+    strategy:
+      fail-fast: false
+      matrix:
+        include:
+          - provider: openai
+            models: ${{ inputs.openai_models }}
+            api_key_env: OPENAI_API_KEY
+            custom_config_env: ""
+            api_base: ""
+            api_version: ""
+            deployment_name: ""
+            required: true
+          - provider: anthropic
+            models: ${{ inputs.anthropic_models }}
+            api_key_env: ANTHROPIC_API_KEY
+            custom_config_env: ""
+            api_base: ""
+            api_version: ""
+            deployment_name: ""
+            required: true
+          - provider: bedrock
+            models: ${{ inputs.bedrock_models }}
+            api_key_env: BEDROCK_API_KEY
+            custom_config_env: ""
+            api_base: ""
+            api_version: ""
+            deployment_name: ""
+            required: false
+          - provider: vertex_ai
+            models: ${{ inputs.vertex_ai_models }}
+            api_key_env: ""
+            custom_config_env: NIGHTLY_LLM_VERTEX_AI_CUSTOM_CONFIG_JSON
+            api_base: ""
+            api_version: ""
+            deployment_name: ""
+            required: false
+          - provider: azure
+            models: ${{ inputs.azure_models }}
+            api_key_env: AZURE_API_KEY
+            custom_config_env: ""
+            api_base: ${{ inputs.azure_api_base }}
+            api_version: "2025-04-01-preview"
+            deployment_name: ""
+            required: false
+          - provider: ollama_chat
+            models: ${{ inputs.ollama_models }}
+            api_key_env: OLLAMA_API_KEY
+            custom_config_env: ""
+            api_base: "https://ollama.com"
+            api_version: ""
+            deployment_name: ""
+            required: false
+          - provider: openrouter
+            models: ${{ inputs.openrouter_models }}
+            api_key_env: OPENROUTER_API_KEY
+            custom_config_env: ""
+            api_base: "https://openrouter.ai/api/v1"
+            api_version: ""
+            deployment_name: ""
+            required: false
+    runs-on:
+      - runs-on
+      - runner=4cpu-linux-arm64
+      - "run-id=${{ github.run_id }}-nightly-${{ matrix.provider }}-provider-chat-test"
+      - extras=ecr-cache
+    timeout-minutes: 45
+    environment: ci-protected
+    steps:
+      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
+
+      - name: Checkout code
+        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
+        with:
+          persist-credentials: false
+
+      - name: Configure AWS credentials
+        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
+        with:
+          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
+          aws-region: us-east-2
+
+      - name: Get AWS Secrets
+        uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
+        with:
+          # Keep JSON values unparsed so vertex custom config is passed as raw JSON.
+          parse-json-secrets: false
+          secret-ids: |
+            DOCKER_USERNAME, test/docker-username
+            DOCKER_TOKEN, test/docker-token
+            OPENAI_API_KEY, test/openai-api-key
+            ANTHROPIC_API_KEY, test/anthropic-api-key
+            BEDROCK_API_KEY, test/bedrock-api-key
+            NIGHTLY_LLM_VERTEX_AI_CUSTOM_CONFIG_JSON, test/nightly-llm-vertex-ai-custom-config-json
+            AZURE_API_KEY, test/azure-api-key
+            OLLAMA_API_KEY, test/ollama-api-key
+            OPENROUTER_API_KEY, test/openrouter-api-key
+
+      - name: Run nightly provider chat test
+        uses: ./.github/actions/run-nightly-provider-chat-test
+        with:
+          provider: ${{ matrix.provider }}
+          models: ${{ matrix.models }}
+          provider-api-key: ${{ matrix.api_key_env && env[matrix.api_key_env] || '' }}
+          strict: ${{ inputs.strict && 'true' || 'false' }}
+          api-base: ${{ matrix.api_base }}
+          api-version: ${{ matrix.api_version }}
+          deployment-name: ${{ matrix.deployment_name }}
+          custom-config-json: ${{ matrix.custom_config_env && env[matrix.custom_config_env] || '' }}
+          runs-on-ecr-cache: ${{ env.RUNS_ON_ECR_CACHE }}
+          run-id: ${{ github.run_id }}
+          docker-username: ${{ env.DOCKER_USERNAME }}
+          docker-token: ${{ env.DOCKER_TOKEN }}
+
+      - name: Dump API server logs
+        if: always()
+        run: |
+          cd deployment/docker_compose
+          docker compose logs --no-color api_server > $GITHUB_WORKSPACE/api_server.log || true
+
+      - name: Dump all-container logs
+        if: always()
+        run: |
+          cd deployment/docker_compose
+          docker compose logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
+
+      - name: Upload logs
+        if: always()
+        uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f
+        with:
+          name: docker-all-logs-nightly-${{ matrix.provider }}-llm-provider
+          path: |
+            ${{ github.workspace }}/api_server.log
+            ${{ github.workspace }}/docker-compose.log
+
+      - name: Stop Docker containers
+        if: always()
+        run: |
+          cd deployment/docker_compose
+          docker compose down -v
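One technique in this workflow worth highlighting: each matrix entry names its API-key variable as a string (`api_key_env`), and the step resolves it at runtime with the `env[...]` index expression, with `|| ''` keeping the value empty for providers like vertex_ai that authenticate via a JSON config instead. A stripped-down sketch of the same indirection, with hypothetical job and secret names:

```yaml
# Minimal illustration of env indirection driven by a matrix value.
jobs:
  probe:
    runs-on: ubuntu-latest
    env:
      OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
    strategy:
      matrix:
        include:
          - provider: openai
            api_key_env: OPENAI_API_KEY
          - provider: vertex_ai
            api_key_env: ""
    steps:
      # env[matrix.api_key_env] looks up the variable whose name is stored
      # in the matrix; the && / || guard yields '' when api_key_env is blank.
      - env:
          PROVIDER_API_KEY: ${{ matrix.api_key_env && env[matrix.api_key_env] || '' }}
        run: |
          if [ -n "$PROVIDER_API_KEY" ]; then
            echo "key present for ${{ matrix.provider }}"
          else
            echo "no key configured for ${{ matrix.provider }}"
          fi
```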
6  .github/workflows/sandbox-deployment.yml  vendored
@@ -110,7 +110,7 @@ jobs:
           persist-credentials: false

       - name: Configure AWS credentials
-        uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
+        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
         with:
           role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
           aws-region: us-east-2
@@ -180,7 +180,7 @@ jobs:
           persist-credentials: false

       - name: Configure AWS credentials
-        uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
+        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
         with:
           role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
           aws-region: us-east-2
@@ -244,7 +244,7 @@ jobs:
       - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

       - name: Configure AWS credentials
-        uses: aws-actions/configure-aws-credentials@61815dcd50bd041e203e49132bacad1fd04d2708
+        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
         with:
           role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
           aws-region: us-east-2
69  .github/workflows/storybook-deploy.yml  vendored  Normal file
@@ -0,0 +1,69 @@
+name: Storybook Deploy
+env:
+  VERCEL_ORG_ID: ${{ secrets.VERCEL_ORG_ID }}
+  VERCEL_PROJECT_ID: prj_sG49mVsA25UsxIPhN2pmBJlikJZM
+  VERCEL_CLI: vercel@50.14.1
+  VERCEL_TOKEN: ${{ secrets.VERCEL_TOKEN }}
+
+concurrency:
+  group: storybook-deploy-production
+  cancel-in-progress: true
+
+on:
+  workflow_dispatch:
+  push:
+    branches:
+      - main
+    paths:
+      - "web/lib/opal/**"
+      - "web/src/refresh-components/**"
+      - "web/.storybook/**"
+      - "web/package.json"
+      - "web/package-lock.json"
+permissions:
+  contents: read
+jobs:
+  Deploy-Storybook:
+    runs-on: ubuntu-latest
+    timeout-minutes: 30
+    steps:
+      - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v4
+        with:
+          persist-credentials: false
+
+      - name: Setup node
+        uses: actions/setup-node@6044e13b5dc448c55e2357c09f80417699197238 # ratchet:actions/setup-node@v4
+        with:
+          node-version: 22
+          cache: "npm"
+          cache-dependency-path: ./web/package-lock.json
+
+      - name: Install dependencies
+        working-directory: web
+        run: npm ci
+
+      - name: Build Storybook
+        working-directory: web
+        run: npm run storybook:build
+
+      - name: Deploy to Vercel (Production)
+        working-directory: web
+        run: npx --yes "$VERCEL_CLI" deploy storybook-static/ --prod --yes --token="$VERCEL_TOKEN"
+
+  notify-slack-on-failure:
+    needs: Deploy-Storybook
+    if: always() && needs.Deploy-Storybook.result == 'failure'
+    runs-on: ubuntu-latest
+    timeout-minutes: 10
+    steps:
+      - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v4
+        with:
+          persist-credentials: false
+          sparse-checkout: .github/actions/slack-notify
+
+      - name: Send Slack notification
+        uses: ./.github/actions/slack-notify
+        with:
+          webhook-url: ${{ secrets.MONITOR_DEPLOYMENTS_WEBHOOK }}
+          failed-jobs: "• Deploy-Storybook"
+          title: "🚨 Storybook Deploy Failed"
2  .github/workflows/zizmor.yml  vendored
@@ -24,7 +24,7 @@ jobs:
           persist-credentials: false

       - name: Install the latest version of uv
-        uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
+        uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7
         with:
           enable-cache: false
           version: "0.9.9"
@@ -119,10 +119,11 @@ repos:
       ]

 - repo: https://github.com/golangci/golangci-lint
-  rev: 9f61b0f53f80672872fced07b6874397c3ed197b # frozen: v2.7.2
+  rev: 5d1e709b7be35cb2025444e19de266b056b7b7ee # frozen: v2.10.1
   hooks:
     - id: golangci-lint
-      entry: bash -c "find tools/ -name go.mod -print0 | xargs -0 -I{} bash -c 'cd \"$(dirname {})\" && golangci-lint run ./...'"
+      language_version: "1.26.0"
+      entry: bash -c "find . -name go.mod -not -path './.venv/*' -print0 | xargs -0 -I{} bash -c 'cd \"$(dirname {})\" && golangci-lint run ./...'"

 - repo: https://github.com/astral-sh/ruff-pre-commit
   # Ruff version.
3  .vscode/env_template.txt  vendored
@@ -7,6 +7,9 @@


 AUTH_TYPE=basic
+# Recommended for basic auth - used for signing password reset and verification tokens
+# Generate a secure value with: openssl rand -hex 32
+USER_AUTH_SECRET=""
 DEV_MODE=true


58  .vscode/launch.json  vendored
@@ -40,19 +40,7 @@
       }
     },
-    {
-      "name": "Celery (lightweight mode)",
-      "configurations": [
-        "Celery primary",
-        "Celery background",
-        "Celery beat"
-      ],
-      "presentation": {
-        "group": "1"
-      },
-      "stopAll": true
-    },
     {
-      "name": "Celery (standard mode)",
+      "name": "Celery",
       "configurations": [
         "Celery primary",
         "Celery light",
@@ -253,35 +241,6 @@
       },
       "consoleTitle": "Celery light Console"
     },
-    {
-      "name": "Celery background",
-      "type": "debugpy",
-      "request": "launch",
-      "module": "celery",
-      "cwd": "${workspaceFolder}/backend",
-      "envFile": "${workspaceFolder}/.vscode/.env",
-      "env": {
-        "LOG_LEVEL": "INFO",
-        "PYTHONUNBUFFERED": "1",
-        "PYTHONPATH": "."
-      },
-      "args": [
-        "-A",
-        "onyx.background.celery.versioned_apps.background",
-        "worker",
-        "--pool=threads",
-        "--concurrency=20",
-        "--prefetch-multiplier=4",
-        "--loglevel=INFO",
-        "--hostname=background@%n",
-        "-Q",
-        "vespa_metadata_sync,connector_deletion,doc_permissions_upsert,checkpoint_cleanup,index_attempt_cleanup,docprocessing,connector_doc_fetching,connector_pruning,connector_doc_permissions_sync,connector_external_group_sync,csv_generation,kg_processing,monitoring,user_file_processing,user_file_project_sync,user_file_delete,opensearch_migration"
-      ],
-      "presentation": {
-        "group": "2"
-      },
-      "consoleTitle": "Celery background Console"
-    },
     {
       "name": "Celery heavy",
       "type": "debugpy",
@@ -526,21 +485,6 @@
         "group": "3"
       }
     },
-    {
-      "name": "Clear and Restart OpenSearch Container",
-      // Generic debugger type, required arg but has no bearing on bash.
-      "type": "node",
-      "request": "launch",
-      "runtimeExecutable": "bash",
-      "runtimeArgs": [
-        "${workspaceFolder}/backend/scripts/restart_opensearch_container.sh"
-      ],
-      "cwd": "${workspaceFolder}",
-      "console": "integratedTerminal",
-      "presentation": {
-        "group": "3"
-      }
-    },
     {
       "name": "Eval CLI",
       "type": "debugpy",
84  AGENTS.md
@@ -86,37 +86,6 @@ Onyx uses Celery for asynchronous task processing with multiple specialized work
|
||||
- Monitoring tasks (every 5 minutes)
|
||||
- Cleanup tasks (hourly)
|
||||
|
||||
#### Worker Deployment Modes
|
||||
|
||||
Onyx supports two deployment modes for background workers, controlled by the `USE_LIGHTWEIGHT_BACKGROUND_WORKER` environment variable:
|
||||
|
||||
**Lightweight Mode** (default, `USE_LIGHTWEIGHT_BACKGROUND_WORKER=true`):
|
||||
|
||||
- Runs a single consolidated `background` worker that handles all background tasks:
|
||||
- Light worker tasks (Vespa operations, permissions sync, deletion)
|
||||
- Document processing (indexing pipeline)
|
||||
- Document fetching (connector data retrieval)
|
||||
- Pruning operations (from `heavy` worker)
|
||||
- Knowledge graph processing (from `kg_processing` worker)
|
||||
- Monitoring tasks (from `monitoring` worker)
|
||||
- User file processing (from `user_file_processing` worker)
|
||||
- Lower resource footprint (fewer worker processes)
|
||||
- Suitable for smaller deployments or development environments
|
||||
- Default concurrency: 20 threads (increased to handle combined workload)
|
||||
|
||||
**Standard Mode** (`USE_LIGHTWEIGHT_BACKGROUND_WORKER=false`):
|
||||
|
||||
- Runs separate specialized workers as documented above (light, docprocessing, docfetching, heavy, kg_processing, monitoring, user_file_processing)
|
||||
- Better isolation and scalability
|
||||
- Can scale individual workers independently based on workload
|
||||
- Suitable for production deployments with higher load
|
||||
|
||||
The deployment mode affects:
|
||||
|
||||
- **Backend**: Worker processes spawned by supervisord or dev scripts
|
||||
- **Helm**: Which Kubernetes deployments are created
|
||||
- **Dev Environment**: Which workers `dev_run_background_jobs.py` spawns
|
||||
|
||||
#### Key Features

- **Thread-based Workers**: All workers use thread pools (not processes) for stability

@@ -135,6 +104,10 @@ The deployment mode affects:

- Always use `@shared_task` rather than `@celery_app`
- Put tasks under `background/celery/tasks/` or `ee/background/celery/tasks`
- Never enqueue a task without an expiration. Always supply `expires=` when sending tasks, either from the beat schedule or directly from another task. It is never acceptable to submit code that enqueues tasks without an expiration, as doing so can lead to unbounded task queue growth (see the sketch after this list).
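
A hedged illustration of the expiration rule; the task name and schedule are made up, while `apply_async(expires=...)` and the beat-schedule `options` dict are standard Celery:

```python
from celery import shared_task


@shared_task
def cleanup_stale_rows() -> None:  # hypothetical task, for illustration only
    ...


# Enqueued directly from other code: always bounded by an expiration.
cleanup_stale_rows.apply_async(expires=600)  # seconds

# Enqueued from the beat schedule: same guarantee via `options`.
beat_schedule = {
    "cleanup-stale-rows": {
        "task": "cleanup_stale_rows",
        "schedule": 3600,
        "options": {"expires": 600},
    },
}
```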
**Defining APIs**:
When creating new FastAPI APIs, do NOT use the `response_model` field. Instead, just type the
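
The guidance above is cut off by the diff, but presumably it continues "type the return value." A minimal sketch under that assumption (the endpoint and model names are hypothetical; FastAPI natively infers the response schema from the return annotation):

```python
from fastapi import APIRouter
from pydantic import BaseModel

router = APIRouter()


class SessionSnapshot(BaseModel):  # hypothetical response model
    id: int
    name: str


# The return type annotation doubles as the response schema;
# no response_model= argument is needed.
@router.get("/sessions/{session_id}")
def get_session(session_id: int) -> SessionSnapshot:
    return SessionSnapshot(id=session_id, name="example")
```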
@@ -548,7 +521,7 @@ class in the utils over directly calling the APIs with a library like `requests`
calling the utilities directly (e.g. do NOT create admin users with
`admin_user = UserManager.create(name="admin_user")`, instead use the `admin_user` fixture).

A great example of this type of test is `backend/tests/integration/dev_apis/test_simple_chat_api.py`.
A great example of this type of test is `backend/tests/integration/tests/streaming_endpoints/test_chat_stream.py`.

To run them:

@@ -571,6 +544,8 @@ To run them:
npx playwright test <TEST_NAME>
```

For shared fixtures, best practices, and detailed guidance, see `backend/tests/README.md`.

## Logs

When (1) writing integration tests or (2) doing live tests (e.g. curl / playwright) you can get access

@@ -616,3 +591,48 @@ This is a minimal list - feel free to include more. Do NOT write code as part of
Keep it high level. You can reference certain files or functions though.

Before writing your plan, make sure to do research. Explore the relevant sections in the codebase.

## Error Handling

**Always raise `OnyxError` from `onyx.error_handling.exceptions` instead of `HTTPException`.
Never hardcode status codes or use `starlette.status` / `fastapi.status` constants directly.**

A global FastAPI exception handler converts `OnyxError` into a JSON response with the standard
`{"error_code": "...", "detail": "..."}` shape. This eliminates boilerplate and keeps error
handling consistent across the entire backend.

```python
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError

# ✅ Good
raise OnyxError(OnyxErrorCode.NOT_FOUND, "Session not found")

# ✅ Good — no extra message needed
raise OnyxError(OnyxErrorCode.UNAUTHENTICATED)

# ✅ Good — upstream service with dynamic status code
raise OnyxError(OnyxErrorCode.BAD_GATEWAY, detail, status_code_override=upstream_status)

# ❌ Bad — using HTTPException directly
raise HTTPException(status_code=404, detail="Session not found")

# ❌ Bad — starlette constant
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Access denied")
```

Available error codes are defined in `backend/onyx/error_handling/error_codes.py`. If a new error
category is needed, add it there first — do not invent ad-hoc codes.

**Upstream service errors:** When forwarding errors from an upstream service where the HTTP
status code is dynamic (comes from the upstream response), use `status_code_override`:

```python
raise OnyxError(OnyxErrorCode.BAD_GATEWAY, detail, status_code_override=e.response.status_code)
```

## Best Practices

In addition to the other content in this file, best practices for contributing
to the codebase can be found at `contributing_guides/best_practices.md`.
Understand its contents and follow them.
@@ -46,7 +46,9 @@ RUN apt-get update && \
    pkg-config \
    gcc \
    nano \
    vim && \
    vim \
    libjemalloc2 \
    && \
    rm -rf /var/lib/apt/lists/* && \
    apt-get clean

@@ -141,6 +143,7 @@ COPY --chown=onyx:onyx ./scripts/debugging /app/scripts/debugging
COPY --chown=onyx:onyx ./scripts/force_delete_connector_by_id.py /app/scripts/force_delete_connector_by_id.py
COPY --chown=onyx:onyx ./scripts/supervisord_entrypoint.sh /app/scripts/supervisord_entrypoint.sh
COPY --chown=onyx:onyx ./scripts/setup_craft_templates.sh /app/scripts/setup_craft_templates.sh
COPY --chown=onyx:onyx ./scripts/reencrypt_secrets.py /app/scripts/reencrypt_secrets.py
RUN chmod +x /app/scripts/supervisord_entrypoint.sh /app/scripts/setup_craft_templates.sh

# Run Craft template setup at build time when ENABLE_CRAFT=true
@@ -164,6 +167,13 @@ ENV PYTHONPATH=/app
ARG ONYX_VERSION=0.0.0-dev
ENV ONYX_VERSION=${ONYX_VERSION}

# Use jemalloc instead of glibc malloc to reduce memory fragmentation
# in long-running Python processes (API server, Celery workers).
# The soname is architecture-independent; the dynamic linker resolves
# the correct path from standard library directories.
# Placed after all RUN steps so build-time processes are unaffected.
ENV LD_PRELOAD=libjemalloc.so.2

# Default command which does nothing
# This container is used by api server and background which specify their own CMD
CMD ["tail", "-f", "/dev/null"]
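
A quick way to confirm the preload took effect inside a running container. This is a hedged sanity check, not part of the image itself, and it only works on Linux:

```python
def jemalloc_loaded() -> bool:
    """Check whether jemalloc is mapped into the current process (Linux only)."""
    with open("/proc/self/maps") as f:
        return any("libjemalloc" in line for line in f)


print(jemalloc_loaded())
```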
@@ -21,15 +21,14 @@ import sys
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List, NamedTuple
from typing import NamedTuple

from alembic.config import Config
from alembic.script import ScriptDirectory
from sqlalchemy import text

from onyx.db.engine.sql_engine import is_valid_schema_name
from onyx.db.engine.sql_engine import SqlEngine
from onyx.db.engine.tenant_utils import get_all_tenant_ids
from onyx.db.engine.tenant_utils import get_schemas_needing_migration
from shared_configs.configs import TENANT_ID_PREFIX


@@ -105,56 +104,6 @@ def get_head_revision() -> str | None:
    return script.get_current_head()


def get_schemas_needing_migration(
    tenant_schemas: List[str], head_rev: str
) -> List[str]:
    """Return only schemas whose current alembic version is not at head."""
    if not tenant_schemas:
        return []

    engine = SqlEngine.get_engine()

    with engine.connect() as conn:
        # Find which schemas actually have an alembic_version table
        rows = conn.execute(
            text(
                "SELECT table_schema FROM information_schema.tables "
                "WHERE table_name = 'alembic_version' "
                "AND table_schema = ANY(:schemas)"
            ),
            {"schemas": tenant_schemas},
        )
        schemas_with_table = set(row[0] for row in rows)

        # Schemas without the table definitely need migration
        needs_migration = [s for s in tenant_schemas if s not in schemas_with_table]

        if not schemas_with_table:
            return needs_migration

        # Validate schema names before interpolating into SQL
        for schema in schemas_with_table:
            if not is_valid_schema_name(schema):
                raise ValueError(f"Invalid schema name: {schema}")

        # Single query to get every schema's current revision at once.
        # Use integer tags instead of interpolating schema names into
        # string literals to avoid quoting issues.
        schema_list = list(schemas_with_table)
        union_parts = [
            f'SELECT {i} AS idx, version_num FROM "{schema}".alembic_version'
            for i, schema in enumerate(schema_list)
        ]
        rows = conn.execute(text(" UNION ALL ".join(union_parts)))
        version_by_schema = {schema_list[row[0]]: row[1] for row in rows}

        needs_migration.extend(
            s for s in schemas_with_table if version_by_schema.get(s) != head_rev
        )

        return needs_migration
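
This removed function now lives in `onyx.db.engine.tenant_utils`, per the updated imports above. A hedged sketch of how the runner presumably consumes it; the exact call site is not shown in this diff:

```python
head = get_head_revision()  # defined earlier in this script
tenant_schemas = get_all_tenant_ids()  # assumption: tenant IDs double as schema names
if head is not None:
    schemas_to_migrate = get_schemas_needing_migration(tenant_schemas, head)
```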
def run_migrations_parallel(
    schemas: list[str],
    max_workers: int,
@@ -0,0 +1,29 @@
"""code interpreter seed

Revision ID: 07b98176f1de
Revises: 7cb492013621
Create Date: 2026-02-23 15:55:07.606784

"""

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "07b98176f1de"
down_revision = "7cb492013621"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Seed the single instance of code_interpreter_server
    # NOTE: There should always be exactly one code_interpreter_server row
    op.execute(
        sa.text("INSERT INTO code_interpreter_server (server_enabled) VALUES (true)")
    )


def downgrade() -> None:
    op.execute(sa.text("DELETE FROM code_interpreter_server"))
@@ -0,0 +1,28 @@
"""add scim_username to scim_user_mapping

Revision ID: 0bb4558f35df
Revises: 631fd2504136
Create Date: 2026-02-20 10:45:30.340188

"""

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "0bb4558f35df"
down_revision = "631fd2504136"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.add_column(
        "scim_user_mapping",
        sa.Column("scim_username", sa.String(), nullable=True),
    )


def downgrade() -> None:
    op.drop_column("scim_user_mapping", "scim_username")
@@ -0,0 +1,37 @@
"""add cache_store table

Revision ID: 2664261bfaab
Revises: 4a1e4b1c89d2
Create Date: 2026-02-27 00:00:00.000000

"""

from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "2664261bfaab"
down_revision = "4a1e4b1c89d2"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.create_table(
        "cache_store",
        sa.Column("key", sa.String(), nullable=False),
        sa.Column("value", sa.LargeBinary(), nullable=True),
        sa.Column("expires_at", sa.DateTime(timezone=True), nullable=True),
        sa.PrimaryKeyConstraint("key"),
    )
    op.create_index(
        "ix_cache_store_expires",
        "cache_store",
        ["expires_at"],
        postgresql_where=sa.text("expires_at IS NOT NULL"),
    )


def downgrade() -> None:
    op.drop_index("ix_cache_store_expires", table_name="cache_store")
    op.drop_table("cache_store")
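
A hedged sketch of the semantics this table supports: a key/value store where rows past `expires_at` count as misses (the partial index above keeps expiry scans cheap). The function name and connection handling are illustrative, not the real cache-backend implementation:

```python
from datetime import datetime, timezone

from sqlalchemy import text
from sqlalchemy.engine import Connection


def cache_get(conn: Connection, key: str) -> bytes | None:
    # Rows whose expires_at has passed are treated as cache misses.
    row = conn.execute(
        text(
            "SELECT value FROM cache_store "
            "WHERE key = :key AND (expires_at IS NULL OR expires_at > :now)"
        ),
        {"key": key, "now": datetime.now(timezone.utc)},
    ).fetchone()
    return row[0] if row else None
```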
@@ -0,0 +1,51 @@
"""Add INDEXING to UserFileStatus

Revision ID: 4a1e4b1c89d2
Revises: 6b3b4083c5aa
Create Date: 2026-02-28 00:00:00.000000

"""

import sqlalchemy as sa
from alembic import op

revision = "4a1e4b1c89d2"
down_revision = "6b3b4083c5aa"
branch_labels = None
depends_on = None

TABLE = "user_file"
COLUMN = "status"
CONSTRAINT_NAME = "ck_user_file_status"

OLD_VALUES = ("PROCESSING", "COMPLETED", "FAILED", "CANCELED", "DELETING")
NEW_VALUES = ("PROCESSING", "INDEXING", "COMPLETED", "FAILED", "CANCELED", "DELETING")


def _drop_status_check_constraint() -> None:
    """Drop the existing CHECK constraint on user_file.status.

    The constraint name is auto-generated by SQLAlchemy and unknown,
    so we look it up via the inspector.
    """
    inspector = sa.inspect(op.get_bind())
    for constraint in inspector.get_check_constraints(TABLE):
        if COLUMN in constraint.get("sqltext", ""):
            constraint_name = constraint["name"]
            if constraint_name is not None:
                op.drop_constraint(constraint_name, TABLE, type_="check")


def upgrade() -> None:
    _drop_status_check_constraint()
    in_clause = ", ".join(f"'{v}'" for v in NEW_VALUES)
    op.create_check_constraint(CONSTRAINT_NAME, TABLE, f"{COLUMN} IN ({in_clause})")


def downgrade() -> None:
    op.execute(
        f"UPDATE {TABLE} SET {COLUMN} = 'PROCESSING' WHERE {COLUMN} = 'INDEXING'"
    )
    op.drop_constraint(CONSTRAINT_NAME, TABLE, type_="check")
    in_clause = ", ".join(f"'{v}'" for v in OLD_VALUES)
    op.create_check_constraint(CONSTRAINT_NAME, TABLE, f"{COLUMN} IN ({in_clause})")
@@ -0,0 +1,69 @@
"""add python tool on default

Revision ID: 57122d037335
Revises: c0c937d5c9e5
Create Date: 2026-02-27 10:10:40.124925

"""

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "57122d037335"
down_revision = "c0c937d5c9e5"
branch_labels = None
depends_on = None


PYTHON_TOOL_NAME = "python"


def upgrade() -> None:
    conn = op.get_bind()

    # Look up the PythonTool id
    result = conn.execute(
        sa.text("SELECT id FROM tool WHERE name = :name"),
        {"name": PYTHON_TOOL_NAME},
    ).fetchone()

    if not result:
        return

    tool_id = result[0]

    # Attach to the default persona (id=0) if not already attached
    conn.execute(
        sa.text(
            """
            INSERT INTO persona__tool (persona_id, tool_id)
            VALUES (0, :tool_id)
            ON CONFLICT DO NOTHING
            """
        ),
        {"tool_id": tool_id},
    )


def downgrade() -> None:
    conn = op.get_bind()

    result = conn.execute(
        sa.text("SELECT id FROM tool WHERE name = :name"),
        {"name": PYTHON_TOOL_NAME},
    ).fetchone()

    if not result:
        return

    conn.execute(
        sa.text(
            """
            DELETE FROM persona__tool
            WHERE persona_id = 0 AND tool_id = :tool_id
            """
        ),
        {"tool_id": result[0]},
    )
@@ -0,0 +1,32 @@
"""add approx_chunk_count_in_vespa to opensearch tenant migration

Revision ID: 631fd2504136
Revises: c7f2e1b4a9d3
Create Date: 2026-02-18 21:07:52.831215

"""

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "631fd2504136"
down_revision = "c7f2e1b4a9d3"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.add_column(
        "opensearch_tenant_migration_record",
        sa.Column(
            "approx_chunk_count_in_vespa",
            sa.Integer(),
            nullable=True,
        ),
    )


def downgrade() -> None:
    op.drop_column("opensearch_tenant_migration_record", "approx_chunk_count_in_vespa")
@@ -0,0 +1,112 @@
"""persona cleanup and featured

Revision ID: 6b3b4083c5aa
Revises: 57122d037335
Create Date: 2026-02-26 12:00:00.000000

"""

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "6b3b4083c5aa"
down_revision = "57122d037335"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Add featured column with nullable=True first
    op.add_column("persona", sa.Column("featured", sa.Boolean(), nullable=True))

    # Migrate data from is_default_persona to featured
    op.execute("UPDATE persona SET featured = is_default_persona")

    # Make featured non-nullable with default=False
    op.alter_column(
        "persona",
        "featured",
        existing_type=sa.Boolean(),
        nullable=False,
        server_default=sa.false(),
    )

    # Drop is_default_persona column
    op.drop_column("persona", "is_default_persona")

    # Drop unused columns
    op.drop_column("persona", "num_chunks")
    op.drop_column("persona", "chunks_above")
    op.drop_column("persona", "chunks_below")
    op.drop_column("persona", "llm_relevance_filter")
    op.drop_column("persona", "llm_filter_extraction")
    op.drop_column("persona", "recency_bias")


def downgrade() -> None:
    # Add back recency_bias column
    op.add_column(
        "persona",
        sa.Column(
            "recency_bias",
            sa.VARCHAR(),
            nullable=False,
            server_default="base_decay",
        ),
    )

    # Add back llm_filter_extraction column
    op.add_column(
        "persona",
        sa.Column(
            "llm_filter_extraction",
            sa.Boolean(),
            nullable=False,
            server_default=sa.false(),
        ),
    )

    # Add back llm_relevance_filter column
    op.add_column(
        "persona",
        sa.Column(
            "llm_relevance_filter",
            sa.Boolean(),
            nullable=False,
            server_default=sa.false(),
        ),
    )

    # Add back chunks_below column
    op.add_column(
        "persona",
        sa.Column("chunks_below", sa.Integer(), nullable=False, server_default="0"),
    )

    # Add back chunks_above column
    op.add_column(
        "persona",
        sa.Column("chunks_above", sa.Integer(), nullable=False, server_default="0"),
    )

    # Add back num_chunks column
    op.add_column("persona", sa.Column("num_chunks", sa.Float(), nullable=True))

    # Add back is_default_persona column
    op.add_column(
        "persona",
        sa.Column(
            "is_default_persona",
            sa.Boolean(),
            nullable=False,
            server_default=sa.false(),
        ),
    )

    # Migrate data from featured to is_default_persona
    op.execute("UPDATE persona SET is_default_persona = featured")

    # Drop featured column
    op.drop_column("persona", "featured")
@@ -0,0 +1,48 @@
"""add enterprise and name fields to scim_user_mapping

Revision ID: 7616121f6e97
Revises: 07b98176f1de
Create Date: 2026-02-23 12:00:00.000000

"""

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "7616121f6e97"
down_revision = "07b98176f1de"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.add_column(
        "scim_user_mapping",
        sa.Column("department", sa.String(), nullable=True),
    )
    op.add_column(
        "scim_user_mapping",
        sa.Column("manager", sa.String(), nullable=True),
    )
    op.add_column(
        "scim_user_mapping",
        sa.Column("given_name", sa.String(), nullable=True),
    )
    op.add_column(
        "scim_user_mapping",
        sa.Column("family_name", sa.String(), nullable=True),
    )
    op.add_column(
        "scim_user_mapping",
        sa.Column("scim_emails_json", sa.Text(), nullable=True),
    )


def downgrade() -> None:
    op.drop_column("scim_user_mapping", "scim_emails_json")
    op.drop_column("scim_user_mapping", "family_name")
    op.drop_column("scim_user_mapping", "given_name")
    op.drop_column("scim_user_mapping", "manager")
    op.drop_column("scim_user_mapping", "department")
@@ -0,0 +1,31 @@
"""code interpreter server model

Revision ID: 7cb492013621
Revises: 0bb4558f35df
Create Date: 2026-02-22 18:54:54.007265

"""

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "7cb492013621"
down_revision = "0bb4558f35df"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.create_table(
        "code_interpreter_server",
        sa.Column("id", sa.Integer, primary_key=True),
        sa.Column(
            "server_enabled", sa.Boolean, nullable=False, server_default=sa.true()
        ),
    )


def downgrade() -> None:
    op.drop_table("code_interpreter_server")
@@ -0,0 +1,33 @@
"""add needs_persona_sync to user_file

Revision ID: 8ffcc2bcfc11
Revises: 7616121f6e97
Create Date: 2026-02-23 10:48:48.343826

"""

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "8ffcc2bcfc11"
down_revision = "7616121f6e97"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.add_column(
        "user_file",
        sa.Column(
            "needs_persona_sync",
            sa.Boolean(),
            nullable=False,
            server_default=sa.text("false"),
        ),
    )


def downgrade() -> None:
    op.drop_column("user_file", "needs_persona_sync")
@@ -0,0 +1,34 @@
"""make scim_user_mapping.external_id nullable

Revision ID: a3b8d9e2f1c4
Revises: 2664261bfaab
Create Date: 2026-03-02

"""

from alembic import op


# revision identifiers, used by Alembic.
revision = "a3b8d9e2f1c4"
down_revision = "2664261bfaab"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.alter_column(
        "scim_user_mapping",
        "external_id",
        nullable=True,
    )


def downgrade() -> None:
    # Delete any rows where external_id is NULL before re-applying NOT NULL
    op.execute("DELETE FROM scim_user_mapping WHERE external_id IS NULL")
    op.alter_column(
        "scim_user_mapping",
        "external_id",
        nullable=False,
    )
@@ -0,0 +1,51 @@
"""add hierarchy_node_by_connector_credential_pair table

Revision ID: b5c4d7e8f9a1
Revises: a3b8d9e2f1c4
Create Date: 2026-03-04

"""

import sqlalchemy as sa
from alembic import op

revision = "b5c4d7e8f9a1"
down_revision = "a3b8d9e2f1c4"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.create_table(
        "hierarchy_node_by_connector_credential_pair",
        sa.Column("hierarchy_node_id", sa.Integer(), nullable=False),
        sa.Column("connector_id", sa.Integer(), nullable=False),
        sa.Column("credential_id", sa.Integer(), nullable=False),
        sa.ForeignKeyConstraint(
            ["hierarchy_node_id"],
            ["hierarchy_node.id"],
            ondelete="CASCADE",
        ),
        sa.ForeignKeyConstraint(
            ["connector_id", "credential_id"],
            [
                "connector_credential_pair.connector_id",
                "connector_credential_pair.credential_id",
            ],
            ondelete="CASCADE",
        ),
        sa.PrimaryKeyConstraint("hierarchy_node_id", "connector_id", "credential_id"),
    )
    op.create_index(
        "ix_hierarchy_node_cc_pair_connector_credential",
        "hierarchy_node_by_connector_credential_pair",
        ["connector_id", "credential_id"],
    )


def downgrade() -> None:
    op.drop_index(
        "ix_hierarchy_node_cc_pair_connector_credential",
        table_name="hierarchy_node_by_connector_credential_pair",
    )
    op.drop_table("hierarchy_node_by_connector_credential_pair")
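
A hedged ORM sketch mirroring the association table created above. The class name and declarative style are illustrative and may not match the real model in `onyx.db.models`:

```python
import sqlalchemy as sa
from sqlalchemy.orm import DeclarativeBase


class Base(DeclarativeBase):
    pass


class HierarchyNodeByConnectorCredentialPair(Base):  # hypothetical model name
    __tablename__ = "hierarchy_node_by_connector_credential_pair"

    # Composite primary key: one row per (node, cc-pair) link.
    hierarchy_node_id = sa.Column(
        sa.Integer,
        sa.ForeignKey("hierarchy_node.id", ondelete="CASCADE"),
        primary_key=True,
    )
    connector_id = sa.Column(sa.Integer, primary_key=True)
    credential_id = sa.Column(sa.Integer, primary_key=True)

    __table_args__ = (
        # Composite FK onto the cc-pair's own composite key, cascading deletes.
        sa.ForeignKeyConstraint(
            ["connector_id", "credential_id"],
            [
                "connector_credential_pair.connector_id",
                "connector_credential_pair.credential_id",
            ],
            ondelete="CASCADE",
        ),
        sa.Index(
            "ix_hierarchy_node_cc_pair_connector_credential",
            "connector_id",
            "credential_id",
        ),
    )
```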
@@ -0,0 +1,70 @@
"""llm provider deprecate fields

Revision ID: c0c937d5c9e5
Revises: 8ffcc2bcfc11
Create Date: 2026-02-25 17:35:46.125102

"""

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "c0c937d5c9e5"
down_revision = "8ffcc2bcfc11"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Make default_model_name nullable (was NOT NULL)
    op.alter_column(
        "llm_provider",
        "default_model_name",
        existing_type=sa.String(),
        nullable=True,
    )

    # Drop unique constraint on is_default_provider (defaults now tracked via LLMModelFlow)
    op.drop_constraint(
        "llm_provider_is_default_provider_key",
        "llm_provider",
        type_="unique",
    )

    # Remove server_default from is_default_vision_provider (was server_default=false())
    op.alter_column(
        "llm_provider",
        "is_default_vision_provider",
        existing_type=sa.Boolean(),
        server_default=None,
    )


def downgrade() -> None:
    # Restore default_model_name to NOT NULL (set empty string for any NULLs first)
    op.execute(
        "UPDATE llm_provider SET default_model_name = '' WHERE default_model_name IS NULL"
    )
    op.alter_column(
        "llm_provider",
        "default_model_name",
        existing_type=sa.String(),
        nullable=False,
    )

    # Restore unique constraint on is_default_provider
    op.create_unique_constraint(
        "llm_provider_is_default_provider_key",
        "llm_provider",
        ["is_default_provider"],
    )

    # Restore server_default for is_default_vision_provider
    op.alter_column(
        "llm_provider",
        "is_default_vision_provider",
        existing_type=sa.Boolean(),
        server_default=sa.false(),
    )
@@ -0,0 +1,31 @@
"""add sharing_scope to build_session

Revision ID: c7f2e1b4a9d3
Revises: 19c0ccb01687
Create Date: 2026-02-17 12:00:00.000000

"""

from alembic import op
import sqlalchemy as sa

revision = "c7f2e1b4a9d3"
down_revision = "19c0ccb01687"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.add_column(
        "build_session",
        sa.Column(
            "sharing_scope",
            sa.String(),
            nullable=False,
            server_default="private",
        ),
    )


def downgrade() -> None:
    op.drop_column("build_session", "sharing_scope")
@@ -11,7 +11,6 @@ from sqlalchemy import text
from alembic import op
from onyx.configs.app_configs import DB_READONLY_PASSWORD
from onyx.configs.app_configs import DB_READONLY_USER
from shared_configs.configs import MULTI_TENANT


# revision identifiers, used by Alembic.
@@ -22,59 +21,52 @@ depends_on = None


def upgrade() -> None:
    if MULTI_TENANT:
        # Enable pg_trgm extension if not already enabled
        op.execute("CREATE EXTENSION IF NOT EXISTS pg_trgm")

    # Enable pg_trgm extension if not already enabled
    op.execute("CREATE EXTENSION IF NOT EXISTS pg_trgm")

    # Create the read-only db user if it does not already exist.
    if not (DB_READONLY_USER and DB_READONLY_PASSWORD):
        raise Exception("DB_READONLY_USER or DB_READONLY_PASSWORD is not set")

    # Create read-only db user here only in multi-tenant mode. For single-tenant mode,
    # the user is created in the standard migration.
    if not (DB_READONLY_USER and DB_READONLY_PASSWORD):
        raise Exception("DB_READONLY_USER or DB_READONLY_PASSWORD is not set")

    op.execute(
        text(
            f"""
            DO $$
            BEGIN
                -- Check if the read-only user already exists
                IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN
                    -- Create the read-only user with the specified password
                    EXECUTE format('CREATE USER %I WITH PASSWORD %L', '{DB_READONLY_USER}', '{DB_READONLY_PASSWORD}');
                    -- First revoke all privileges to ensure a clean slate
                    EXECUTE format('REVOKE ALL ON DATABASE %I FROM %I', current_database(), '{DB_READONLY_USER}');
                    -- Grant only the CONNECT privilege to allow the user to connect to the database
                    -- but not perform any operations without additional specific grants
                    EXECUTE format('GRANT CONNECT ON DATABASE %I TO %I', current_database(), '{DB_READONLY_USER}');
                END IF;
            END
            $$;
            """
        )
    )


def downgrade() -> None:
    if MULTI_TENANT:
        # Drop read-only db user here only in single tenant mode. For multi-tenant mode,
        # the user is dropped in the alembic_tenants migration.

        op.execute(
            text(
                f"""
                DO $$
                BEGIN
                    IF EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN
                        -- First revoke all privileges from the database
                        EXECUTE format('REVOKE ALL ON DATABASE %I FROM %I', current_database(), '{DB_READONLY_USER}');
                        -- Then revoke all privileges from the public schema
                        EXECUTE format('REVOKE ALL ON SCHEMA public FROM %I', '{DB_READONLY_USER}');
                        -- Then drop the user
                        EXECUTE format('DROP USER %I', '{DB_READONLY_USER}');
                    END IF;
                END
                $$;
                """
            )
        )
        op.execute(text("DROP EXTENSION IF EXISTS pg_trgm"))


def downgrade() -> None:
    op.execute(
        text(
            f"""
            DO $$
            BEGIN
                IF EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN
                    -- First revoke all privileges from the database
                    EXECUTE format('REVOKE ALL ON DATABASE %I FROM %I', current_database(), '{DB_READONLY_USER}');
                    -- Then revoke all privileges from the public schema
                    EXECUTE format('REVOKE ALL ON SCHEMA public FROM %I', '{DB_READONLY_USER}');
                    -- Then drop the user
                    EXECUTE format('DROP USER %I', '{DB_READONLY_USER}');
                END IF;
            END
            $$;
            """
        )
    )
    op.execute(text("DROP EXTENSION IF EXISTS pg_trgm"))
@@ -9,12 +9,15 @@ from onyx.access.access import (
    _get_access_for_documents as get_access_for_documents_without_groups,
)
from onyx.access.access import _get_acl_for_user as get_acl_for_user_without_groups
from onyx.access.access import collect_user_file_access
from onyx.access.models import DocumentAccess
from onyx.access.utils import prefix_external_group
from onyx.access.utils import prefix_user_group
from onyx.db.document import get_document_sources
from onyx.db.document import get_documents_by_ids
from onyx.db.models import User
from onyx.db.models import UserFile
from onyx.db.user_file import fetch_user_files_with_access_relationships
from onyx.utils.logger import setup_logger


@@ -116,6 +119,68 @@ def _get_access_for_documents(
    return access_map


def _collect_user_file_group_names(user_file: UserFile) -> set[str]:
    """Extract user-group names from the already-loaded Persona.groups
    relationships on a UserFile (skipping deleted personas)."""
    groups: set[str] = set()
    for persona in user_file.assistants:
        if persona.deleted:
            continue
        for group in persona.groups:
            groups.add(group.name)
    return groups


def get_access_for_user_files_impl(
    user_file_ids: list[str],
    db_session: Session,
) -> dict[str, DocumentAccess]:
    """EE version: extends the MIT user file ACL with user group names
    from personas shared via user groups.

    Uses a single DB query (via fetch_user_files_with_access_relationships)
    that eagerly loads both the MIT-needed and EE-needed relationships.

    NOTE: is imported in onyx.access.access by `fetch_versioned_implementation`
    DO NOT REMOVE."""
    user_files = fetch_user_files_with_access_relationships(
        user_file_ids, db_session, eager_load_groups=True
    )
    return build_access_for_user_files_impl(user_files)


def build_access_for_user_files_impl(
    user_files: list[UserFile],
) -> dict[str, DocumentAccess]:
    """EE version: works on pre-loaded UserFile objects.
    Expects Persona.groups to be eagerly loaded.

    NOTE: is imported in onyx.access.access by `fetch_versioned_implementation`
    DO NOT REMOVE."""
    result: dict[str, DocumentAccess] = {}
    for user_file in user_files:
        if user_file.user is None:
            result[str(user_file.id)] = DocumentAccess.build(
                user_emails=[],
                user_groups=[],
                is_public=True,
                external_user_emails=[],
                external_user_group_ids=[],
            )
            continue

        emails, is_public = collect_user_file_access(user_file)
        group_names = _collect_user_file_group_names(user_file)
        result[str(user_file.id)] = DocumentAccess.build(
            user_emails=list(emails),
            user_groups=list(group_names),
            is_public=is_public,
            external_user_emails=[],
            external_user_group_ids=[],
        )
    return result


def _get_acl_for_user(user: User, db_session: Session) -> set[str]:
    """Returns a list of ACL entries that the user has access to. This is meant to be
    used downstream to filter out documents that the user does not have access to. The
@@ -1,3 +1,4 @@
import os
from datetime import datetime

import jwt
@@ -20,7 +21,13 @@ logger = setup_logger()


def verify_auth_setting() -> None:
    # All the Auth flows are valid for EE version
    # All the Auth flows are valid for EE version, but warn about deprecated 'disabled'
    raw_auth_type = (os.environ.get("AUTH_TYPE") or "").lower()
    if raw_auth_type == "disabled":
        logger.warning(
            "AUTH_TYPE='disabled' is no longer supported. "
            "Using 'basic' instead. Please update your configuration."
        )
    logger.notice(f"Using Auth Type: {AUTH_TYPE.value}")
@@ -1,15 +0,0 @@
from onyx.background.celery.apps import app_base
from onyx.background.celery.apps.background import celery_app


celery_app.autodiscover_tasks(
    app_base.filter_task_modules(
        [
            "ee.onyx.background.celery.tasks.doc_permission_syncing",
            "ee.onyx.background.celery.tasks.external_group_syncing",
            "ee.onyx.background.celery.tasks.cleanup",
            "ee.onyx.background.celery.tasks.tenant_provisioning",
            "ee.onyx.background.celery.tasks.query_history",
        ]
    )
)
@@ -18,7 +18,7 @@ from onyx.db.models import HierarchyNode


def _build_hierarchy_access_filter(
    user_email: str | None,
    user_email: str,
    external_group_ids: list[str],
) -> ColumnElement[bool]:
    """Build SQLAlchemy filter for hierarchy node access.
@@ -43,7 +43,7 @@ def _build_hierarchy_access_filter(
def _get_accessible_hierarchy_nodes_for_source(
    db_session: Session,
    source: DocumentSource,
    user_email: str | None,
    user_email: str,
    external_group_ids: list[str],
) -> list[HierarchyNode]:
    """
@@ -11,11 +11,10 @@ from ee.onyx.server.license.models import LicenseMetadata
from ee.onyx.server.license.models import LicensePayload
from ee.onyx.server.license.models import LicenseSource
from onyx.auth.schemas import UserRole
from onyx.cache.factory import get_cache_backend
from onyx.configs.constants import ANONYMOUS_USER_EMAIL
from onyx.db.models import License
from onyx.db.models import User
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import get_current_tenant_id
@@ -142,7 +141,7 @@ def get_used_seats(tenant_id: str | None = None) -> int:

def get_cached_license_metadata(tenant_id: str | None = None) -> LicenseMetadata | None:
    """
    Get license metadata from Redis cache.
    Get license metadata from cache.

    Args:
        tenant_id: Tenant ID (for multi-tenant deployments)
@@ -150,38 +149,34 @@ def get_cached_license_metadata(tenant_id: str | None = None) -> LicenseMetadata
    Returns:
        LicenseMetadata if cached, None otherwise
    """
    tenant = tenant_id or get_current_tenant_id()
    redis_client = get_redis_replica_client(tenant_id=tenant)
    cache = get_cache_backend(tenant_id=tenant_id)
    cached = cache.get(LICENSE_METADATA_KEY)
    if not cached:
        return None

    cached = redis_client.get(LICENSE_METADATA_KEY)
    if cached:
        try:
            cached_str: str
            if isinstance(cached, bytes):
                cached_str = cached.decode("utf-8")
            else:
                cached_str = str(cached)
            return LicenseMetadata.model_validate_json(cached_str)
        except Exception as e:
            logger.warning(f"Failed to parse cached license metadata: {e}")
            return None
    return None
    try:
        cached_str = (
            cached.decode("utf-8") if isinstance(cached, bytes) else str(cached)
        )
        return LicenseMetadata.model_validate_json(cached_str)
    except Exception as e:
        logger.warning(f"Failed to parse cached license metadata: {e}")
        return None


def invalidate_license_cache(tenant_id: str | None = None) -> None:
    """
    Invalidate the license metadata cache (not the license itself).

    This deletes the cached LicenseMetadata from Redis. The actual license
    in the database is not affected. Redis delete is idempotent - if the
    key doesn't exist, this is a no-op.
    Deletes the cached LicenseMetadata. The actual license in the database
    is not affected. Delete is idempotent — if the key doesn't exist, this
    is a no-op.

    Args:
        tenant_id: Tenant ID (for multi-tenant deployments)
    """
    tenant = tenant_id or get_current_tenant_id()
    redis_client = get_redis_client(tenant_id=tenant)
    redis_client.delete(LICENSE_METADATA_KEY)
    cache = get_cache_backend(tenant_id=tenant_id)
    cache.delete(LICENSE_METADATA_KEY)
    logger.info("License cache invalidated")


@@ -192,7 +187,7 @@ def update_license_cache(
    tenant_id: str | None = None,
) -> LicenseMetadata:
    """
    Update the Redis cache with license metadata.
    Update the cache with license metadata.

    We cache all license statuses (ACTIVE, GRACE_PERIOD, GATED_ACCESS) because:
    1. Frontend needs status to show appropriate UI/banners
@@ -211,7 +206,7 @@ def update_license_cache(
    from ee.onyx.utils.license import get_license_status

    tenant = tenant_id or get_current_tenant_id()
    redis_client = get_redis_client(tenant_id=tenant)
    cache = get_cache_backend(tenant_id=tenant_id)

    used_seats = get_used_seats(tenant)
    status = get_license_status(payload, grace_period_end)
@@ -230,7 +225,7 @@ def update_license_cache(
        stripe_subscription_id=payload.stripe_subscription_id,
    )

    redis_client.set(
    cache.set(
        LICENSE_METADATA_KEY,
        metadata.model_dump_json(),
        ex=LICENSE_CACHE_TTL_SECONDS,
@@ -7,6 +7,7 @@ from onyx.db.models import Persona
from onyx.db.models import Persona__User
from onyx.db.models import Persona__UserGroup
from onyx.db.notification import create_notification
from onyx.db.persona import mark_persona_user_files_for_sync
from onyx.server.features.persona.models import PersonaSharedNotificationData


@@ -26,7 +27,9 @@ def update_persona_access(

    NOTE: Callers are responsible for committing."""

    needs_sync = False
    if is_public is not None:
        needs_sync = True
        persona = db_session.query(Persona).filter(Persona.id == persona_id).first()
        if persona:
            persona.is_public = is_public
@@ -35,6 +38,7 @@ def update_persona_access(
    # and a non-empty list means "replace with these shares".

    if user_ids is not None:
        needs_sync = True
        db_session.query(Persona__User).filter(
            Persona__User.persona_id == persona_id
        ).delete(synchronize_session="fetch")
@@ -54,6 +58,7 @@ def update_persona_access(
        )

    if group_ids is not None:
        needs_sync = True
        db_session.query(Persona__UserGroup).filter(
            Persona__UserGroup.persona_id == persona_id
        ).delete(synchronize_session="fetch")
@@ -63,3 +68,7 @@ def update_persona_access(
            db_session.add(
                Persona__UserGroup(persona_id=persona_id, user_group_id=group_id)
            )

    # When sharing changes, user file ACLs need to be updated in the vector DB
    if needs_sync:
        mark_persona_user_files_for_sync(persona_id, db_session)
@@ -34,6 +34,7 @@ from sqlalchemy.dialects.postgresql import insert as pg_insert

from ee.onyx.server.scim.filtering import ScimFilter
from ee.onyx.server.scim.filtering import ScimFilterOperator
from ee.onyx.server.scim.models import ScimMappingFields
from onyx.db.dal import DAL
from onyx.db.models import ScimGroupMapping
from onyx.db.models import ScimToken
@@ -125,11 +126,27 @@ class ScimDAL(DAL):

    def create_user_mapping(
        self,
        external_id: str,
        external_id: str | None,
        user_id: UUID,
        scim_username: str | None = None,
        fields: ScimMappingFields | None = None,
    ) -> ScimUserMapping:
        """Create a mapping between a SCIM externalId and an Onyx user."""
        mapping = ScimUserMapping(external_id=external_id, user_id=user_id)
        """Create a SCIM mapping for a user.

        ``external_id`` may be ``None`` when the IdP omits it (RFC 7643
        allows this). The mapping still marks the user as SCIM-managed.
        """
        f = fields or ScimMappingFields()
        mapping = ScimUserMapping(
            external_id=external_id,
            user_id=user_id,
            scim_username=scim_username,
            department=f.department,
            manager=f.manager,
            given_name=f.given_name,
            family_name=f.family_name,
            scim_emails_json=f.scim_emails_json,
        )
        self._session.add(mapping)
        self._session.flush()
        return mapping
@@ -248,17 +265,22 @@ class ScimDAL(DAL):
        scim_filter: ScimFilter | None,
        start_index: int = 1,
        count: int = 100,
    ) -> tuple[list[tuple[User, str | None]], int]:
    ) -> tuple[list[tuple[User, ScimUserMapping | None]], int]:
        """Query users with optional SCIM filter and pagination.

        Returns:
            A tuple of (list of (user, external_id) pairs, total_count).
            A tuple of (list of (user, mapping) pairs, total_count).

        Raises:
            ValueError: If the filter uses an unsupported attribute.
        """
        query = select(User).where(
            User.role.notin_([UserRole.SLACK_USER, UserRole.EXT_PERM_USER])
        # Inner-join with ScimUserMapping so only SCIM-managed users appear.
        # Pre-existing system accounts (anonymous, admin, etc.) are excluded
        # unless they were explicitly linked via SCIM provisioning.
        query = (
            select(User)
            .join(ScimUserMapping, ScimUserMapping.user_id == User.id)
            .where(User.role.notin_([UserRole.SLACK_USER, UserRole.EXT_PERM_USER]))
        )

        if scim_filter:
@@ -292,33 +314,120 @@ class ScimDAL(DAL):
        users = list(
            self._session.scalars(
                query.order_by(User.id).offset(offset).limit(count)  # type: ignore[arg-type]
            ).all()
        )
            .unique()
            .all()
        )

        # Batch-fetch external IDs to avoid N+1 queries
        ext_id_map = self._get_user_external_ids([u.id for u in users])
        return [(u, ext_id_map.get(u.id)) for u in users], total
        # Batch-fetch SCIM mappings to avoid N+1 queries
        mapping_map = self._get_user_mappings_batch([u.id for u in users])
        return [(u, mapping_map.get(u.id)) for u in users], total

    def sync_user_external_id(self, user_id: UUID, new_external_id: str | None) -> None:
        """Create, update, or delete the external ID mapping for a user."""
    def sync_user_external_id(
        self,
        user_id: UUID,
        new_external_id: str | None,
        scim_username: str | None = None,
        fields: ScimMappingFields | None = None,
    ) -> None:
        """Sync the SCIM mapping for a user.

        If a mapping already exists, its fields are updated (including
        setting ``external_id`` to ``None`` when the IdP omits it).
        If no mapping exists and ``new_external_id`` is provided, a new
        mapping is created. A mapping is never deleted here — SCIM-managed
        users must retain their mapping to remain visible in ``GET /Users``.

        When *fields* is provided, all mapping fields are written
        unconditionally — including ``None`` values — so that a caller can
        clear a previously-set field (e.g. removing a department).
        """
        mapping = self.get_user_mapping_by_user_id(user_id)
        if new_external_id:
            if mapping:
                if mapping.external_id != new_external_id:
                    mapping.external_id = new_external_id
            else:
                self.create_user_mapping(external_id=new_external_id, user_id=user_id)
        elif mapping:
            self.delete_user_mapping(mapping.id)
        if mapping:
            if mapping.external_id != new_external_id:
                mapping.external_id = new_external_id
            if scim_username is not None:
                mapping.scim_username = scim_username
            if fields is not None:
                mapping.department = fields.department
                mapping.manager = fields.manager
                mapping.given_name = fields.given_name
                mapping.family_name = fields.family_name
                mapping.scim_emails_json = fields.scim_emails_json
        elif new_external_id:
            self.create_user_mapping(
                external_id=new_external_id,
                user_id=user_id,
                scim_username=scim_username,
                fields=fields,
            )

    def _get_user_external_ids(self, user_ids: list[UUID]) -> dict[UUID, str]:
        """Batch-fetch external IDs for a list of user IDs."""
    def _get_user_mappings_batch(
        self, user_ids: list[UUID]
    ) -> dict[UUID, ScimUserMapping]:
        """Batch-fetch SCIM user mappings keyed by user ID."""
        if not user_ids:
            return {}
        mappings = self._session.scalars(
            select(ScimUserMapping).where(ScimUserMapping.user_id.in_(user_ids))
        ).all()
        return {m.user_id: m.external_id for m in mappings}
        return {m.user_id: m for m in mappings}

    def get_user_groups(self, user_id: UUID) -> list[tuple[int, str]]:
        """Get groups a user belongs to as ``(group_id, group_name)`` pairs.

        Excludes groups marked for deletion.
        """
        rels = self._session.scalars(
            select(User__UserGroup).where(User__UserGroup.user_id == user_id)
        ).all()

        group_ids = [r.user_group_id for r in rels]
        if not group_ids:
            return []

        groups = self._session.scalars(
            select(UserGroup).where(
                UserGroup.id.in_(group_ids),
                UserGroup.is_up_for_deletion.is_(False),
            )
        ).all()
        return [(g.id, g.name) for g in groups]

    def get_users_groups_batch(
        self, user_ids: list[UUID]
    ) -> dict[UUID, list[tuple[int, str]]]:
        """Batch-fetch group memberships for multiple users.

        Returns a mapping of ``user_id → [(group_id, group_name), ...]``.
        Avoids N+1 queries when building user list responses.
        """
        if not user_ids:
            return {}

        rels = self._session.scalars(
            select(User__UserGroup).where(User__UserGroup.user_id.in_(user_ids))
        ).all()

        group_ids = list({r.user_group_id for r in rels})
        if not group_ids:
            return {}

        groups = self._session.scalars(
            select(UserGroup).where(
                UserGroup.id.in_(group_ids),
                UserGroup.is_up_for_deletion.is_(False),
            )
        ).all()
        groups_by_id = {g.id: g.name for g in groups}

        result: dict[UUID, list[tuple[int, str]]] = {}
        for r in rels:
            if r.user_id and r.user_group_id in groups_by_id:
                result.setdefault(r.user_id, []).append(
                    (r.user_group_id, groups_by_id[r.user_group_id])
                )
        return result

    # ------------------------------------------------------------------
    # Group mapping operations
@@ -483,9 +592,13 @@ class ScimDAL(DAL):
        if not user_ids:
            return []

        users = self._session.scalars(
            select(User).where(User.id.in_(user_ids))  # type: ignore[attr-defined]
        ).all()
        users = (
            self._session.scalars(
                select(User).where(User.id.in_(user_ids))  # type: ignore[attr-defined]
            )
            .unique()
            .all()
        )
        users_by_id = {u.id: u for u in users}

        return [
@@ -504,9 +617,13 @@ class ScimDAL(DAL):
        """
        if not uuids:
            return []
        existing_users = self._session.scalars(
            select(User).where(User.id.in_(uuids))  # type: ignore[attr-defined]
        ).all()
        existing_users = (
            self._session.scalars(
                select(User).where(User.id.in_(uuids))  # type: ignore[attr-defined]
            )
            .unique()
            .all()
        )
        existing_ids = {u.id for u in existing_users}
        return [uid for uid in uuids if uid not in existing_ids]
@@ -9,20 +9,26 @@ from sqlalchemy import Select
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.server.user_group.models import SetCuratorRequest
|
||||
from ee.onyx.server.user_group.models import UserGroupCreate
|
||||
from ee.onyx.server.user_group.models import UserGroupUpdate
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import Credential
|
||||
from onyx.db.models import Credential__UserGroup
|
||||
from onyx.db.models import Document
|
||||
from onyx.db.models import DocumentByConnectorCredentialPair
|
||||
from onyx.db.models import DocumentSet
|
||||
from onyx.db.models import DocumentSet__UserGroup
|
||||
from onyx.db.models import FederatedConnector__DocumentSet
|
||||
from onyx.db.models import LLMProvider__UserGroup
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import Persona__UserGroup
|
||||
from onyx.db.models import TokenRateLimit__UserGroup
|
||||
from onyx.db.models import User
|
||||
@@ -195,8 +201,60 @@ def fetch_user_group(db_session: Session, user_group_id: int) -> UserGroup | Non
|
||||
return db_session.scalar(stmt)
|
||||
|
||||
|
||||
def _add_user_group_snapshot_eager_loads(
|
||||
stmt: Select,
|
||||
) -> Select:
|
||||
"""Add eager loading options needed by UserGroup.from_model snapshot creation."""
|
||||
return stmt.options(
|
||||
selectinload(UserGroup.users),
|
||||
selectinload(UserGroup.user_group_relationships),
|
||||
selectinload(UserGroup.cc_pair_relationships)
|
||||
.selectinload(UserGroup__ConnectorCredentialPair.cc_pair)
|
||||
.options(
|
||||
selectinload(ConnectorCredentialPair.connector),
|
||||
selectinload(ConnectorCredentialPair.credential).selectinload(
|
||||
Credential.user
|
||||
),
|
||||
),
|
||||
selectinload(UserGroup.document_sets).options(
|
||||
selectinload(DocumentSet.connector_credential_pairs).selectinload(
|
||||
ConnectorCredentialPair.connector
|
||||
),
|
||||
selectinload(DocumentSet.users),
|
||||
selectinload(DocumentSet.groups),
|
||||
selectinload(DocumentSet.federated_connectors).selectinload(
|
||||
FederatedConnector__DocumentSet.federated_connector
|
||||
),
|
||||
),
|
||||
selectinload(UserGroup.personas).options(
|
||||
selectinload(Persona.tools),
|
||||
selectinload(Persona.hierarchy_nodes),
|
||||
selectinload(Persona.attached_documents).selectinload(
|
||||
Document.parent_hierarchy_node
|
||||
),
|
||||
selectinload(Persona.labels),
|
||||
selectinload(Persona.document_sets).options(
|
||||
selectinload(DocumentSet.connector_credential_pairs).selectinload(
|
||||
ConnectorCredentialPair.connector
|
||||
),
|
||||
selectinload(DocumentSet.users),
|
||||
selectinload(DocumentSet.groups),
|
||||
selectinload(DocumentSet.federated_connectors).selectinload(
|
||||
FederatedConnector__DocumentSet.federated_connector
|
||||
),
|
||||
),
|
||||
selectinload(Persona.user),
|
||||
selectinload(Persona.user_files),
|
||||
selectinload(Persona.users),
|
||||
selectinload(Persona.groups),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def fetch_user_groups(
|
||||
db_session: Session, only_up_to_date: bool = True
|
||||
db_session: Session,
|
||||
only_up_to_date: bool = True,
|
||||
eager_load_for_snapshot: bool = False,
|
||||
) -> Sequence[UserGroup]:
|
||||
"""
|
||||
Fetches user groups from the database.
|
||||
@@ -209,6 +267,8 @@ def fetch_user_groups(
|
||||
db_session (Session): The SQLAlchemy session used to query the database.
|
||||
only_up_to_date (bool, optional): Flag to determine whether to filter the results
|
||||
to include only up to date user groups. Defaults to `True`.
|
||||
eager_load_for_snapshot: If True, adds eager loading for all relationships
|
||||
needed by UserGroup.from_model snapshot creation.
|
||||
|
||||
Returns:
|
||||
Sequence[UserGroup]: A sequence of `UserGroup` objects matching the query criteria.
|
||||
@@ -216,11 +276,16 @@ def fetch_user_groups(
|
||||
stmt = select(UserGroup)
|
||||
if only_up_to_date:
|
||||
stmt = stmt.where(UserGroup.is_up_to_date == True) # noqa: E712
|
||||
return db_session.scalars(stmt).all()
|
||||
if eager_load_for_snapshot:
|
||||
stmt = _add_user_group_snapshot_eager_loads(stmt)
|
||||
return db_session.scalars(stmt).unique().all()
|
||||
|
||||
|
||||
def fetch_user_groups_for_user(
|
||||
db_session: Session, user_id: UUID, only_curator_groups: bool = False
|
||||
db_session: Session,
|
||||
user_id: UUID,
|
||||
only_curator_groups: bool = False,
|
||||
eager_load_for_snapshot: bool = False,
|
||||
) -> Sequence[UserGroup]:
|
||||
stmt = (
|
||||
select(UserGroup)
|
||||
@@ -230,7 +295,9 @@ def fetch_user_groups_for_user(
|
||||
)
|
||||
if only_curator_groups:
|
||||
stmt = stmt.where(User__UserGroup.is_curator == True) # noqa: E712
|
||||
return db_session.scalars(stmt).all()
|
||||
if eager_load_for_snapshot:
|
||||
stmt = _add_user_group_snapshot_eager_loads(stmt)
|
||||
return db_session.scalars(stmt).unique().all()
|
||||
|
||||
|
||||
 def construct_document_id_select_by_usergroup(
@@ -405,7 +472,9 @@ def _add_user_group__cc_pair_relationships__no_commit(


 def insert_user_group(db_session: Session, user_group: UserGroupCreate) -> UserGroup:
     db_user_group = UserGroup(
-        name=user_group.name, time_last_modified_by_user=func.now()
+        name=user_group.name,
+        time_last_modified_by_user=func.now(),
+        is_up_to_date=DISABLE_VECTOR_DB,
     )
     db_session.add(db_user_group)
     db_session.flush()  # give the group an ID
@@ -708,8 +777,7 @@ def update_user_group(
         cc_pair_ids=user_group_update.cc_pair_ids,
     )

-    # only needs to sync with Vespa if the cc_pairs have been updated
-    if cc_pairs_updated:
+    if cc_pairs_updated and not DISABLE_VECTOR_DB:
         db_user_group.is_up_to_date = False

     removed_users = db_session.scalars(

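The two `DISABLE_VECTOR_DB` changes work together: with the vector DB disabled, a new group is created already marked `is_up_to_date=True`, and cc-pair edits no longer flip it back to `False`, so no Vespa sync is ever triggered for it. The flag's definition is not part of this diff; a hedged sketch of the usual config pattern:

    # Assumed config pattern - the real definition lives in the app configs.
    import os

    DISABLE_VECTOR_DB = os.environ.get("DISABLE_VECTOR_DB", "").lower() == "true"
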
@@ -68,6 +68,7 @@ def get_external_access_for_raw_gdrive_file(
     company_domain: str,
     retriever_drive_service: GoogleDriveService | None,
     admin_drive_service: GoogleDriveService,
+    fallback_user_email: str,
     add_prefix: bool = False,
 ) -> ExternalAccess:
     """
@@ -79,6 +80,11 @@ def get_external_access_for_raw_gdrive_file(
     set add_prefix to True so group IDs are prefixed with the source type.
     When invoked from doc_sync (permission sync), use the default (False)
     since upsert_document_external_perms handles prefixing.
+    fallback_user_email: When we cannot retrieve any permission info for a file
+        (e.g. externally-owned files where the API returns no permissions
+        and permissions.list returns 403), fall back to granting access
+        to this user. This is typically the impersonated org user whose
+        drive contained the file.
     """
     doc_id = file.get("id")
     if not doc_id:
@@ -117,6 +123,26 @@ def get_external_access_for_raw_gdrive_file(
         [permissions_list, backup_permissions_list]
     )

+    # For externally-owned files, the Drive API may return no permissions
+    # and permissions.list may return 403. In this case, fall back to
+    # granting access to the user who found the file in their drive.
+    # Note, even if other users also have access to this file,
+    # they will not be granted access in Onyx.
+    # We check permissions_list (the final result after all fetch attempts)
+    # rather than the raw fields, because permission_ids may be present
+    # but the actual fetch can still return empty due to a 403.
+    if not permissions_list:
+        logger.info(
+            f"No permission info available for file {doc_id} "
+            f"(likely owned by a user outside of your organization). "
+            f"Falling back to granting access to retriever user: {fallback_user_email}"
+        )
+        return ExternalAccess(
+            external_user_emails={fallback_user_email},
+            external_user_group_ids=set(),
+            is_public=False,
+        )
+
     folder_ids_to_inherit_permissions_from: set[str] = set()
     user_emails: set[str] = set()
     group_emails: set[str] = set()

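For context, a hedged sketch of how a caller would thread the new parameter through; the call-site name and the source of the impersonated email are assumptions, since the caller is not shown in this diff:

    # Hypothetical call site in the Drive permission sync flow.
    external_access = get_external_access_for_raw_gdrive_file(
        file=raw_file,
        company_domain=company_domain,
        retriever_drive_service=retriever_drive_service,
        admin_drive_service=admin_drive_service,
        fallback_user_email=retriever_user_email,  # impersonated org user
    )
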
@@ -1,6 +1,8 @@
 from collections.abc import Generator
+from typing import Any

 from jira import JIRA
+from jira.exceptions import JIRAError

 from ee.onyx.db.external_perm import ExternalUserGroup
 from onyx.connectors.jira.utils import build_jira_client
@@ -9,107 +11,102 @@ from onyx.utils.logger import setup_logger

 logger = setup_logger()

+_ATLASSIAN_ACCOUNT_TYPE = "atlassian"
+_GROUP_MEMBER_PAGE_SIZE = 50

-def _get_jira_group_members_email(
+# The GET /group/member endpoint was introduced in Jira 6.0.
+# Jira versions older than 6.0 do not have group management REST APIs at all.
+_MIN_JIRA_VERSION_FOR_GROUP_MEMBER = "6.0"
+
+
+def _fetch_group_member_page(
     jira_client: JIRA,
     group_name: str,
-) -> list[str]:
-    """Get all member emails for a Jira group.
+    start_at: int,
+) -> dict[str, Any]:
+    """Fetch a single page from the non-deprecated GET /group/member endpoint.

-    Filters out app accounts (bots, integrations) and only returns real user emails.
+    The old GET /group endpoint (used by jira_client.group_members()) is deprecated
+    and decommissioned in Jira Server 10.3+. This uses the replacement endpoint
+    directly via the library's internal _get_json helper, following the same pattern
+    as enhanced_search_ids / bulk_fetch_issues in connector.py.
+
+    There is an open PR to the library to switch to this endpoint since last year:
+    https://github.com/pycontribs/jira/pull/2356
+    so once it is merged and released, we can switch to using the library function.
     """
-    emails: list[str] = []
-
     try:
-        # group_members returns an OrderedDict of account_id -> member_info
-        members = jira_client.group_members(group=group_name)
-
-        if not members:
-            logger.warning(f"No members found for group {group_name}")
-            return emails
-
-        for account_id, member_info in members.items():
-            # member_info is a dict with keys like 'fullname', 'email', 'active'
-            email = member_info.get("email")
-
-            # Skip "hidden" emails - these are typically app accounts
-            if email and email != "hidden":
-                emails.append(email)
-            else:
-                # For cloud, we might need to fetch user details separately
-                try:
-                    user = jira_client.user(id=account_id)
-
-                    # Skip app accounts (bots, integrations, etc.)
-                    if hasattr(user, "accountType") and user.accountType == "app":
-                        logger.info(
-                            f"Skipping app account {account_id} for group {group_name}"
-                        )
-                        continue
-
-                    if hasattr(user, "emailAddress") and user.emailAddress:
-                        emails.append(user.emailAddress)
-                    else:
-                        logger.warning(f"User {account_id} has no email address")
-                except Exception as e:
-                    logger.warning(
-                        f"Could not fetch email for user {account_id} in group {group_name}: {e}"
-                    )
-
-    except Exception as e:
-        logger.error(f"Error fetching members for group {group_name}: {e}")
-
-    return emails
+        return jira_client._get_json(
+            "group/member",
+            params={
+                "groupname": group_name,
+                "includeInactiveUsers": "false",
+                "startAt": start_at,
+                "maxResults": _GROUP_MEMBER_PAGE_SIZE,
+            },
+        )
+    except JIRAError as e:
+        if e.status_code == 404:
+            raise RuntimeError(
+                f"GET /group/member returned 404 for group '{group_name}'. "
+                f"This endpoint requires Jira {_MIN_JIRA_VERSION_FOR_GROUP_MEMBER}+. "
+                f"If you are running a self-hosted Jira instance, please upgrade "
+                f"to at least Jira {_MIN_JIRA_VERSION_FOR_GROUP_MEMBER}."
+            ) from e
+        raise


-def _build_group_member_email_map(
+def _get_group_member_emails(
     jira_client: JIRA,
-) -> dict[str, set[str]]:
-    """Build a map of group names to member emails."""
-    group_member_emails: dict[str, set[str]] = {}
-
-    try:
-        # Get all groups from Jira - returns a list of group name strings
-        group_names = jira_client.groups()
-
-        if not group_names:
-            logger.warning("No groups found in Jira")
-            return group_member_emails
-
-        logger.info(f"Found {len(group_names)} groups in Jira")
-
-        for group_name in group_names:
-            if not group_name:
-                continue
-
-            member_emails = _get_jira_group_members_email(
-                jira_client=jira_client,
-                group_name=group_name,
-            )
-
-            if member_emails:
-                group_member_emails[group_name] = set(member_emails)
-                logger.debug(
-                    f"Found {len(member_emails)} members for group {group_name}"
-                )
-            else:
-                logger.debug(f"No members found for group {group_name}")
-
-    except Exception as e:
-        logger.error(f"Error building group member email map: {e}")
-
-    return group_member_emails
+    group_name: str,
+) -> set[str]:
+    """Get all member emails for a single Jira group.
+
+    Uses the non-deprecated GET /group/member endpoint which returns full user
+    objects including accountType, so we can filter out app/customer accounts
+    without making separate user() calls.
+    """
+    emails: set[str] = set()
+    start_at = 0
+
+    while True:
+        try:
+            page = _fetch_group_member_page(jira_client, group_name, start_at)
+        except Exception as e:
+            logger.error(f"Error fetching members for group {group_name}: {e}")
+            raise
+
+        members: list[dict[str, Any]] = page.get("values", [])
+        for member in members:
+            account_type = member.get("accountType")
+            # On Jira DC < 9.0, accountType is absent; include those users.
+            # On Cloud / DC 9.0+, filter to real user accounts only.
+            if account_type is not None and account_type != _ATLASSIAN_ACCOUNT_TYPE:
+                continue
+
+            email = member.get("emailAddress")
+            if email:
+                emails.add(email)
+            else:
+                logger.warning(
+                    f"Atlassian user {member.get('accountId', 'unknown')} "
+                    f"in group {group_name} has no visible email address"
+                )
+
+        if page.get("isLast", True) or not members:
+            break
+        start_at += len(members)
+
+    return emails


 def jira_group_sync(
     tenant_id: str,  # noqa: ARG001
     cc_pair: ConnectorCredentialPair,
 ) -> Generator[ExternalUserGroup, None, None]:
-    """
-    Sync Jira groups and their members.
+    """Sync Jira groups and their members, yielding one group at a time.

-    This function fetches all groups from Jira and yields ExternalUserGroup
-    objects containing the group ID and member emails.
+    Streams group-by-group rather than accumulating all groups in memory.
     """
     jira_base_url = cc_pair.connector.connector_specific_config.get("jira_base_url", "")
     scoped_token = cc_pair.connector.connector_specific_config.get(
@@ -130,12 +127,26 @@ def jira_group_sync(
         scoped_token=scoped_token,
     )

-    group_member_email_map = _build_group_member_email_map(jira_client=jira_client)
-    if not group_member_email_map:
-        raise ValueError(f"No groups with members found for cc_pair_id={cc_pair.id}")
+    group_names = jira_client.groups()
+    if not group_names:
+        raise ValueError(f"No groups found for cc_pair_id={cc_pair.id}")

-    for group_id, group_member_emails in group_member_email_map.items():
-        yield ExternalUserGroup(
-            id=group_id,
-            user_emails=list(group_member_emails),
+    logger.info(f"Found {len(group_names)} groups in Jira")
+
+    for group_name in group_names:
+        if not group_name:
+            continue
+
+        member_emails = _get_group_member_emails(
+            jira_client=jira_client,
+            group_name=group_name,
+        )
+        if not member_emails:
+            logger.debug(f"No members found for group {group_name}")
+            continue
+
+        logger.debug(f"Found {len(member_emails)} members for group {group_name}")
+        yield ExternalUserGroup(
+            id=group_name,
+            user_emails=list(member_emails),
         )

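For reference, the page payload consumed by the new pagination loop has roughly this shape; the field set follows Atlassian's REST documentation for GET /group/member, and all values below are invented:

    # Example GET /group/member page (abridged, illustrative values):
    {
        "maxResults": 50,
        "startAt": 0,
        "total": 2,
        "isLast": True,
        "values": [
            {
                "accountId": "5b10a2844c20165700ede21g",
                "accountType": "atlassian",
                "emailAddress": "dev1@example.com",
                "active": True,
            },
            {
                "accountId": "5b10ac8d82e05b22cc7d4ef5",
                "accountType": "app",  # filtered out by _get_group_member_emails
                "active": True,
            },
        ],
    }
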
@@ -4,7 +4,6 @@ from collections import deque
 from collections.abc import Callable
 from collections.abc import Generator
 from typing import Any
-from urllib.parse import unquote
 from urllib.parse import urlparse

 import requests as _requests
@@ -69,7 +68,7 @@ def _graph_api_get(
                 continue
             resp.raise_for_status()
             return resp.json()
-        except (_requests.ConnectionError, _requests.Timeout):
+        except (_requests.ConnectionError, _requests.Timeout, _requests.HTTPError):
            if attempt < GRAPH_API_MAX_RETRIES:
                wait = min(2**attempt, 60)
                logger.warning(
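The hunk above widens the retried exception set so HTTP-level failures (the HTTPError raised by raise_for_status on 4xx/5xx) also get capped exponential backoff, not just connection and timeout errors. A self-contained sketch of the same retry shape; GRAPH_API_MAX_RETRIES' value, the URL, and headers are assumptions, since the full function is not shown in this diff:

    import time

    import requests as _requests

    GRAPH_API_MAX_RETRIES = 5  # assumed value


    def _graph_api_get_sketch(url: str, headers: dict[str, str]) -> dict:
        """Minimal sketch: backoff on connection, timeout, and HTTP errors."""
        for attempt in range(GRAPH_API_MAX_RETRIES + 1):
            try:
                resp = _requests.get(url, headers=headers, timeout=30)
                resp.raise_for_status()  # turns 4xx/5xx into HTTPError
                return resp.json()
            except (_requests.ConnectionError, _requests.Timeout, _requests.HTTPError):
                if attempt < GRAPH_API_MAX_RETRIES:
                    time.sleep(min(2**attempt, 60))  # capped exponential backoff
                    continue
                raise
        raise RuntimeError("unreachable")
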
@@ -598,8 +597,12 @@ def get_external_access_from_sharepoint(
         )
     elif site_page:
         site_url = site_page.get("webUrl")
-        # Prefer server-relative URL to avoid OData filters that break on apostrophes
-        server_relative_url = unquote(urlparse(site_url).path)
+        # Keep percent-encoding intact so the path matches the encoding
+        # used by the Office365 library's SPResPath.create_relative(),
+        # which compares against urlparse(context.base_url).path.
+        # Decoding (e.g. %27 → ') causes a mismatch that duplicates
+        # the site prefix in the constructed URL.
+        server_relative_url = urlparse(site_url).path
         file_obj = client_context.web.get_file_by_server_relative_url(
             server_relative_url
         )

@@ -4,7 +4,6 @@ from contextlib import asynccontextmanager
 from fastapi import FastAPI
-from httpx_oauth.clients.google import GoogleOAuth2

 from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED
 from ee.onyx.server.analytics.api import router as analytics_router
 from ee.onyx.server.auth_check import check_ee_router_auth
 from ee.onyx.server.billing.api import router as billing_router
@@ -31,6 +30,7 @@ from ee.onyx.server.query_and_chat.query_backend import (
 from ee.onyx.server.query_and_chat.search_backend import router as search_router
 from ee.onyx.server.query_history.api import router as query_history_router
 from ee.onyx.server.reporting.usage_export_api import router as usage_export_router
+from ee.onyx.server.scim.api import register_scim_exception_handlers
 from ee.onyx.server.scim.api import scim_router
 from ee.onyx.server.seeding import seed_db
 from ee.onyx.server.tenants.api import router as tenants_router
@@ -152,12 +152,9 @@ def get_application() -> FastAPI:
     # License management
     include_router_with_global_prefix_prepended(application, license_router)

-    # Unified billing API - available when license system is enabled
-    # Works for both self-hosted and cloud deployments
-    # TODO(ENG-3533): Once frontend migrates to /admin/billing/*, this becomes the
-    # primary billing API and /tenants/* billing endpoints can be removed
-    if LICENSE_ENFORCEMENT_ENABLED:
-        include_router_with_global_prefix_prepended(application, billing_router)
+    # Unified billing API - always registered in EE.
+    # Each endpoint is protected by the `current_admin_user` dependency (admin auth).
+    include_router_with_global_prefix_prepended(application, billing_router)

     if MULTI_TENANT:
         # Tenant management
@@ -167,6 +164,7 @@ def get_application() -> FastAPI:
     # they use their own SCIM bearer token auth).
     # Not behind APP_API_PREFIX because IdPs expect /scim/v2/... directly.
     application.include_router(scim_router)
+    register_scim_exception_handlers(application)

     # Ensure all routes have auth enabled or are explicitly marked as public
     check_ee_router_auth(application)

@@ -26,7 +26,6 @@ import asyncio
 import httpx
 from fastapi import APIRouter
 from fastapi import Depends
-from fastapi import HTTPException
 from pydantic import BaseModel
 from sqlalchemy.orm import Session

@@ -42,7 +41,6 @@ from ee.onyx.server.billing.models import SeatUpdateRequest
 from ee.onyx.server.billing.models import SeatUpdateResponse
 from ee.onyx.server.billing.models import StripePublishableKeyResponse
 from ee.onyx.server.billing.models import SubscriptionStatusResponse
-from ee.onyx.server.billing.service import BillingServiceError
 from ee.onyx.server.billing.service import (
     create_checkout_session as create_checkout_service,
 )
@@ -58,6 +56,8 @@ from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_OVERRIDE
 from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_URL
 from onyx.configs.app_configs import WEB_DOMAIN
 from onyx.db.engine.sql_engine import get_session
+from onyx.error_handling.error_codes import OnyxErrorCode
+from onyx.error_handling.exceptions import OnyxError
 from onyx.redis.redis_pool import get_shared_redis_client
 from onyx.utils.logger import setup_logger
 from shared_configs.configs import MULTI_TENANT
@@ -169,26 +169,23 @@ async def create_checkout_session(
     if seats is not None:
         used_seats = get_used_seats(tenant_id)
         if seats < used_seats:
-            raise HTTPException(
-                status_code=400,
-                detail=f"Cannot subscribe with fewer seats than current usage. "
+            raise OnyxError(
+                OnyxErrorCode.VALIDATION_ERROR,
+                f"Cannot subscribe with fewer seats than current usage. "
                 f"You have {used_seats} active users/integrations but requested {seats} seats.",
             )

     # Build redirect URL for after checkout completion
     redirect_url = f"{WEB_DOMAIN}/admin/billing?checkout=success"

-    try:
-        return await create_checkout_service(
-            billing_period=billing_period,
-            seats=seats,
-            email=email,
-            license_data=license_data,
-            redirect_url=redirect_url,
-            tenant_id=tenant_id,
-        )
-    except BillingServiceError as e:
-        raise HTTPException(status_code=e.status_code, detail=e.message)
+    return await create_checkout_service(
+        billing_period=billing_period,
+        seats=seats,
+        email=email,
+        license_data=license_data,
+        redirect_url=redirect_url,
+        tenant_id=tenant_id,
+    )


 @router.post("/create-customer-portal-session")
@@ -206,18 +203,15 @@ async def create_customer_portal_session(

     # Self-hosted requires license
     if not MULTI_TENANT and not license_data:
-        raise HTTPException(status_code=400, detail="No license found")
+        raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, "No license found")

     return_url = request.return_url if request else f"{WEB_DOMAIN}/admin/billing"

-    try:
-        return await create_portal_service(
-            license_data=license_data,
-            return_url=return_url,
-            tenant_id=tenant_id,
-        )
-    except BillingServiceError as e:
-        raise HTTPException(status_code=e.status_code, detail=e.message)
+    return await create_portal_service(
+        license_data=license_data,
+        return_url=return_url,
+        tenant_id=tenant_id,
+    )


 @router.get("/billing-information")
@@ -240,9 +234,9 @@ async def get_billing_information(

     # Check circuit breaker (self-hosted only)
     if _is_billing_circuit_open():
-        raise HTTPException(
-            status_code=503,
-            detail="Stripe connection temporarily disabled. Click 'Connect to Stripe' to retry.",
+        raise OnyxError(
+            OnyxErrorCode.SERVICE_UNAVAILABLE,
+            "Stripe connection temporarily disabled. Click 'Connect to Stripe' to retry.",
         )

     try:
@@ -250,11 +244,15 @@
             license_data=license_data,
             tenant_id=tenant_id,
         )
-    except BillingServiceError as e:
+    except OnyxError as e:
         # Open circuit breaker on connection failures (self-hosted only)
-        if e.status_code in (502, 503, 504):
+        if e.status_code in (
+            OnyxErrorCode.BAD_GATEWAY.status_code,
+            OnyxErrorCode.SERVICE_UNAVAILABLE.status_code,
+            OnyxErrorCode.GATEWAY_TIMEOUT.status_code,
+        ):
             _open_billing_circuit()
-        raise HTTPException(status_code=e.status_code, detail=e.message)
+        raise

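`_is_billing_circuit_open` and `_open_billing_circuit` are referenced above but not defined in this diff. A hedged sketch of a Redis-backed breaker consistent with that usage; the key name and TTL are invented, and `get_shared_redis_client` is the helper this file already imports:

    from onyx.redis.redis_pool import get_shared_redis_client

    _BILLING_CIRCUIT_KEY = "billing:circuit_open"  # assumed key name
    _BILLING_CIRCUIT_TTL_SECONDS = 300  # assumed cool-down


    def _open_billing_circuit() -> None:
        """Trip the breaker: billing calls short-circuit until the TTL expires."""
        redis_client = get_shared_redis_client()
        redis_client.set(_BILLING_CIRCUIT_KEY, "1", ex=_BILLING_CIRCUIT_TTL_SECONDS)


    def _is_billing_circuit_open() -> bool:
        """Check whether the breaker is currently tripped."""
        redis_client = get_shared_redis_client()
        return redis_client.exists(_BILLING_CIRCUIT_KEY) > 0
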
@router.post("/seats/update")
|
||||
@@ -274,31 +272,25 @@ async def update_seats(
|
||||
|
||||
# Self-hosted requires license
|
||||
if not MULTI_TENANT and not license_data:
|
||||
raise HTTPException(status_code=400, detail="No license found")
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, "No license found")
|
||||
|
||||
# Validate that new seat count is not less than current used seats
|
||||
used_seats = get_used_seats(tenant_id)
|
||||
if request.new_seat_count < used_seats:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Cannot reduce seats below current usage. "
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.VALIDATION_ERROR,
|
||||
f"Cannot reduce seats below current usage. "
|
||||
f"You have {used_seats} active users/integrations but requested {request.new_seat_count} seats.",
|
||||
)
|
||||
|
||||
try:
|
||||
result = await update_seat_service(
|
||||
new_seat_count=request.new_seat_count,
|
||||
license_data=license_data,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
# Note: Don't store license here - the control plane may still be processing
|
||||
# the subscription update. The frontend should call /license/claim after a
|
||||
# short delay to get the freshly generated license.
|
||||
|
||||
return result
|
||||
except BillingServiceError as e:
|
||||
raise HTTPException(status_code=e.status_code, detail=e.message)
|
||||
# Note: Don't store license here - the control plane may still be processing
|
||||
# the subscription update. The frontend should call /license/claim after a
|
||||
# short delay to get the freshly generated license.
|
||||
return await update_seat_service(
|
||||
new_seat_count=request.new_seat_count,
|
||||
license_data=license_data,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/stripe-publishable-key")
|
||||
@@ -329,18 +321,18 @@ async def get_stripe_publishable_key() -> StripePublishableKeyResponse:
|
||||
if STRIPE_PUBLISHABLE_KEY_OVERRIDE:
|
||||
key = STRIPE_PUBLISHABLE_KEY_OVERRIDE.strip()
|
||||
if not key.startswith("pk_"):
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Invalid Stripe publishable key format",
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INTERNAL_ERROR,
|
||||
"Invalid Stripe publishable key format",
|
||||
)
|
||||
_stripe_publishable_key_cache = key
|
||||
return StripePublishableKeyResponse(publishable_key=key)
|
||||
|
||||
# Fall back to S3 bucket
|
||||
if not STRIPE_PUBLISHABLE_KEY_URL:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Stripe publishable key is not configured",
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INTERNAL_ERROR,
|
||||
"Stripe publishable key is not configured",
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -351,17 +343,17 @@ async def get_stripe_publishable_key() -> StripePublishableKeyResponse:
|
||||
|
||||
# Validate key format
|
||||
if not key.startswith("pk_"):
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Invalid Stripe publishable key format",
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INTERNAL_ERROR,
|
||||
"Invalid Stripe publishable key format",
|
||||
)
|
||||
|
||||
_stripe_publishable_key_cache = key
|
||||
return StripePublishableKeyResponse(publishable_key=key)
|
||||
except httpx.HTTPError:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to fetch Stripe publishable key",
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INTERNAL_ERROR,
|
||||
"Failed to fetch Stripe publishable key",
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -22,6 +22,8 @@ from ee.onyx.server.billing.models import SeatUpdateResponse
 from ee.onyx.server.billing.models import SubscriptionStatusResponse
 from ee.onyx.server.tenants.access import generate_data_plane_token
 from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
+from onyx.error_handling.error_codes import OnyxErrorCode
+from onyx.error_handling.exceptions import OnyxError
 from onyx.utils.logger import setup_logger
 from shared_configs.configs import MULTI_TENANT

@@ -31,15 +33,6 @@ logger = setup_logger()
 _REQUEST_TIMEOUT = 30.0


-class BillingServiceError(Exception):
-    """Exception raised for billing service errors."""
-
-    def __init__(self, message: str, status_code: int = 500):
-        self.message = message
-        self.status_code = status_code
-        super().__init__(self.message)
-
-
 def _get_proxy_headers(license_data: str | None) -> dict[str, str]:
     """Build headers for proxy requests (self-hosted).

@@ -101,7 +94,7 @@
         Response JSON as dict

     Raises:
-        BillingServiceError: If request fails
+        OnyxError: If request fails
     """

     base_url = _get_base_url()
@@ -128,11 +121,17 @@
             except Exception:
                 pass
             logger.error(f"{error_message}: {e.response.status_code} - {detail}")
-            raise BillingServiceError(detail, e.response.status_code)
+            raise OnyxError(
+                OnyxErrorCode.BAD_GATEWAY,
+                detail,
+                status_code_override=e.response.status_code,
+            )

         except httpx.RequestError:
             logger.exception("Failed to connect to billing service")
-            raise BillingServiceError("Failed to connect to billing service", 502)
+            raise OnyxError(
+                OnyxErrorCode.BAD_GATEWAY, "Failed to connect to billing service"
+            )


 async def create_checkout_session(

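The refactor above assumes an OnyxError/OnyxErrorCode pair in which each code carries a default HTTP status that the exception can override. Neither definition appears in this diff, so the following is a hedged sketch of a shape consistent with the call sites (member values and the property name are assumptions):

    from enum import Enum


    class OnyxErrorCode(Enum):
        """Sketch: each member carries a default HTTP status (values assumed)."""

        VALIDATION_ERROR = 400
        INVALID_INPUT = 400
        NOT_FOUND = 404
        INTERNAL_ERROR = 500
        BAD_GATEWAY = 502
        SERVICE_UNAVAILABLE = 503
        GATEWAY_TIMEOUT = 504

        @property
        def status_code(self) -> int:
            return self.value


    class OnyxError(Exception):
        """Sketch: domain error an app-level handler maps to an HTTP response."""

        def __init__(
            self,
            error_code: OnyxErrorCode,
            message: str,
            status_code_override: int | None = None,
        ) -> None:
            self.error_code = error_code
            self.message = message
            self.status_code = status_code_override or error_code.status_code
            super().__init__(message)
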
@@ -223,6 +223,15 @@ def get_active_scim_token(
     token = dal.get_active_token()
     if not token:
         raise HTTPException(status_code=404, detail="No active SCIM token")

+    # Derive the IdP domain from the first synced user as a heuristic.
+    idp_domain: str | None = None
+    mappings, _total = dal.list_user_mappings(start_index=1, count=1)
+    if mappings:
+        user = dal.get_user(mappings[0].user_id)
+        if user and "@" in user.email:
+            idp_domain = user.email.rsplit("@", 1)[1]
+
     return ScimTokenResponse(
         id=token.id,
         name=token.name,
@@ -230,6 +239,7 @@
         is_active=token.is_active,
         created_at=token.created_at,
         last_used_at=token.last_used_at,
+        idp_domain=idp_domain,
     )

@@ -14,7 +14,6 @@ import requests
 from fastapi import APIRouter
 from fastapi import Depends
 from fastapi import File
-from fastapi import HTTPException
 from fastapi import UploadFile
 from sqlalchemy.orm import Session

@@ -35,6 +34,8 @@ from ee.onyx.server.license.models import SeatUsageResponse
 from ee.onyx.utils.license import verify_license_signature
 from onyx.auth.users import User
 from onyx.db.engine.sql_engine import get_session
+from onyx.error_handling.error_codes import OnyxErrorCode
+from onyx.error_handling.exceptions import OnyxError
 from onyx.utils.logger import setup_logger
 from shared_configs.configs import MULTI_TENANT

@@ -127,9 +128,9 @@ async def claim_license(
     2. Without session_id: Re-claim using existing license for auth
     """
     if MULTI_TENANT:
-        raise HTTPException(
-            status_code=400,
-            detail="License claiming is only available for self-hosted deployments",
+        raise OnyxError(
+            OnyxErrorCode.VALIDATION_ERROR,
+            "License claiming is only available for self-hosted deployments",
         )

     try:
@@ -146,15 +147,16 @@
             # Re-claim using existing license for auth
             metadata = get_license_metadata(db_session)
             if not metadata or not metadata.tenant_id:
-                raise HTTPException(
-                    status_code=400,
-                    detail="No license found. Provide session_id after checkout.",
+                raise OnyxError(
+                    OnyxErrorCode.VALIDATION_ERROR,
+                    "No license found. Provide session_id after checkout.",
                 )

             license_row = get_license(db_session)
             if not license_row or not license_row.license_data:
-                raise HTTPException(
-                    status_code=400, detail="No license found in database"
+                raise OnyxError(
+                    OnyxErrorCode.VALIDATION_ERROR,
+                    "No license found in database",
                 )

             url = f"{CLOUD_DATA_PLANE_URL}/proxy/license/{metadata.tenant_id}"
@@ -173,7 +175,7 @@
         license_data = data.get("license")

         if not license_data:
-            raise HTTPException(status_code=404, detail="No license in response")
+            raise OnyxError(OnyxErrorCode.NOT_FOUND, "No license in response")

         # Verify signature before persisting
         payload = verify_license_signature(license_data)
@@ -199,12 +201,14 @@
                 detail = error_data.get("detail", detail)
             except Exception:
                 pass
-        raise HTTPException(status_code=status_code, detail=detail)
+        raise OnyxError(
+            OnyxErrorCode.BAD_GATEWAY, detail, status_code_override=status_code
+        )
     except ValueError as e:
-        raise HTTPException(status_code=400, detail=str(e))
+        raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(e))
     except requests.RequestException:
-        raise HTTPException(
-            status_code=502, detail="Failed to connect to license server"
+        raise OnyxError(
+            OnyxErrorCode.BAD_GATEWAY, "Failed to connect to license server"
         )


@@ -221,9 +225,9 @@ async def upload_license(
     The license file must be cryptographically signed by Onyx.
     """
     if MULTI_TENANT:
-        raise HTTPException(
-            status_code=400,
-            detail="License upload is only available for self-hosted deployments",
+        raise OnyxError(
+            OnyxErrorCode.VALIDATION_ERROR,
+            "License upload is only available for self-hosted deployments",
         )

     try:
@@ -234,14 +238,14 @@
         # Remove any stray whitespace/newlines from user input
         license_data = license_data.strip()
     except UnicodeDecodeError:
-        raise HTTPException(status_code=400, detail="Invalid license file format")
+        raise OnyxError(OnyxErrorCode.INVALID_INPUT, "Invalid license file format")

     # Verify cryptographic signature - this is the only validation needed
     # The license's tenant_id identifies the customer in control plane, not locally
     try:
         payload = verify_license_signature(license_data)
     except ValueError as e:
-        raise HTTPException(status_code=400, detail=str(e))
+        raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(e))

     # Persist to DB and update cache
     upsert_license(db_session, license_data)
@@ -297,9 +301,9 @@
     Admin only - removes license from database and invalidates cache.
     """
     if MULTI_TENANT:
-        raise HTTPException(
-            status_code=400,
-            detail="License deletion is only available for self-hosted deployments",
+        raise OnyxError(
+            OnyxErrorCode.VALIDATION_ERROR,
+            "License deletion is only available for self-hosted deployments",
        )

     try:

@@ -46,7 +46,6 @@ from fastapi import FastAPI
 from fastapi import Request
 from fastapi import Response
 from fastapi.responses import JSONResponse
-from redis.exceptions import RedisError
 from sqlalchemy.exc import SQLAlchemyError

 from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED
@@ -56,6 +55,7 @@ from ee.onyx.configs.license_enforcement_config import (
 )
 from ee.onyx.db.license import get_cached_license_metadata
 from ee.onyx.db.license import refresh_license_cache
+from onyx.cache.interface import CACHE_TRANSIENT_ERRORS
 from onyx.db.engine.sql_engine import get_session_with_current_tenant
 from onyx.server.settings.models import ApplicationStatus
 from shared_configs.contextvars import get_current_tenant_id
@@ -164,9 +164,9 @@ def add_license_enforcement_middleware(
                 "[license_enforcement] No license, allowing community features"
             )
             is_gated = False
-    except RedisError as e:
+    except CACHE_TRANSIENT_ERRORS as e:
         logger.warning(f"Failed to check license metadata: {e}")
-        # Fail open - don't block users due to Redis connectivity issues
+        # Fail open - don't block users due to cache connectivity issues
         is_gated = False

     if is_gated:

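CACHE_TRANSIENT_ERRORS is imported from onyx.cache.interface but not defined in this diff. Since Python's `except` accepts a tuple of exception types, a hedged sketch of a compatible definition (the real one may cover more backends; members here are assumptions):

    from redis.exceptions import RedisError

    CACHE_TRANSIENT_ERRORS: tuple[type[Exception], ...] = (
        RedisError,
        ConnectionError,  # builtin; raised by some cache clients on socket failure
        TimeoutError,
    )
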
@@ -34,7 +34,7 @@ class SendSearchQueryRequest(BaseModel):
     filters: BaseFilters | None = None
     num_docs_fed_to_llm_selection: int | None = None
     run_query_expansion: bool = False
-    num_hits: int = 50
+    num_hits: int = 30

     include_content: bool = False
     stream: bool = False

@@ -15,7 +15,9 @@ from uuid import UUID

 from fastapi import APIRouter
 from fastapi import Depends
+from fastapi import FastAPI
 from fastapi import Query
+from fastapi import Request
 from fastapi import Response
 from fastapi.responses import JSONResponse
 from fastapi_users.password import PasswordHelper
@@ -24,23 +26,26 @@ from sqlalchemy.exc import IntegrityError
 from sqlalchemy.orm import Session

 from ee.onyx.db.scim import ScimDAL
+from ee.onyx.server.scim.auth import ScimAuthError
 from ee.onyx.server.scim.auth import verify_scim_token
 from ee.onyx.server.scim.filtering import parse_scim_filter
-from ee.onyx.server.scim.models import ScimEmail
+from ee.onyx.server.scim.models import SCIM_LIST_RESPONSE_SCHEMA
 from ee.onyx.server.scim.models import ScimError
 from ee.onyx.server.scim.models import ScimGroupMember
 from ee.onyx.server.scim.models import ScimGroupResource
 from ee.onyx.server.scim.models import ScimListResponse
-from ee.onyx.server.scim.models import ScimMeta
+from ee.onyx.server.scim.models import ScimMappingFields
 from ee.onyx.server.scim.models import ScimName
 from ee.onyx.server.scim.models import ScimPatchRequest
 from ee.onyx.server.scim.models import ScimResourceType
 from ee.onyx.server.scim.models import ScimSchemaDefinition
 from ee.onyx.server.scim.models import ScimServiceProviderConfig
 from ee.onyx.server.scim.models import ScimUserResource
 from ee.onyx.server.scim.patch import apply_group_patch
 from ee.onyx.server.scim.patch import apply_user_patch
 from ee.onyx.server.scim.patch import ScimPatchError
+from ee.onyx.server.scim.providers.base import get_default_provider
+from ee.onyx.server.scim.providers.base import ScimProvider
+from ee.onyx.server.scim.providers.base import serialize_emails
+from ee.onyx.server.scim.schema_definitions import ENTERPRISE_USER_SCHEMA_DEF
 from ee.onyx.server.scim.schema_definitions import GROUP_RESOURCE_TYPE
 from ee.onyx.server.scim.schema_definitions import GROUP_SCHEMA_DEF
 from ee.onyx.server.scim.schema_definitions import SERVICE_PROVIDER_CONFIG
@@ -48,21 +53,61 @@ from ee.onyx.server.scim.schema_definitions import USER_RESOURCE_TYPE
 from ee.onyx.server.scim.schema_definitions import USER_SCHEMA_DEF
 from onyx.db.engine.sql_engine import get_session
 from onyx.db.models import ScimToken
+from onyx.db.models import ScimUserMapping
 from onyx.db.models import User
 from onyx.db.models import UserGroup
 from onyx.db.models import UserRole
 from onyx.utils.logger import setup_logger
 from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop

 logger = setup_logger()


+class ScimJSONResponse(JSONResponse):
+    """JSONResponse with Content-Type: application/scim+json (RFC 7644 §3.1)."""
+
+    media_type = "application/scim+json"
+
+
+# NOTE: All URL paths in this router (/ServiceProviderConfig, /ResourceTypes,
+# /Schemas, /Users, /Groups) are mandated by the SCIM spec (RFC 7643/7644).
+# IdPs like Okta and Azure AD hardcode these exact paths, so they cannot be
+# changed to kebab-case.
+
+
 scim_router = APIRouter(prefix="/scim/v2", tags=["SCIM"])

 _pw_helper = PasswordHelper()


+def register_scim_exception_handlers(app: FastAPI) -> None:
+    """Register SCIM-specific exception handlers on the FastAPI app.
+
+    Call this after ``app.include_router(scim_router)`` so that auth
+    failures from ``verify_scim_token`` return RFC 7644 §3.12 error
+    envelopes (with ``schemas`` and ``status`` fields) instead of
+    FastAPI's default ``{"detail": "..."}`` format.
+    """
+
+    @app.exception_handler(ScimAuthError)
+    async def _handle_scim_auth_error(
+        _request: Request, exc: ScimAuthError
+    ) -> ScimJSONResponse:
+        return _scim_error_response(exc.status_code, exc.detail)
+
+
+def _get_provider(
+    _token: ScimToken = Depends(verify_scim_token),
+) -> ScimProvider:
+    """Resolve the SCIM provider for the current request.
+
+    Currently returns OktaProvider for all requests. When multi-provider
+    support is added (ENG-3652), this will resolve based on token metadata
+    or tenant configuration — no endpoint changes required.
+    """
+    return get_default_provider()
+
+
 # ---------------------------------------------------------------------------
 # Service Discovery Endpoints (unauthenticated)
 # ---------------------------------------------------------------------------
@@ -75,15 +120,39 @@ def get_service_provider_config() -> ScimServiceProviderConfig:

@scim_router.get("/ResourceTypes")
|
||||
def get_resource_types() -> list[ScimResourceType]:
|
||||
"""List available SCIM resource types (RFC 7643 §6)."""
|
||||
return [USER_RESOURCE_TYPE, GROUP_RESOURCE_TYPE]
|
||||
def get_resource_types() -> ScimJSONResponse:
|
||||
"""List available SCIM resource types (RFC 7643 §6).
|
||||
|
||||
Wrapped in a ListResponse envelope (RFC 7644 §3.4.2) because IdPs
|
||||
like Entra ID expect a JSON object, not a bare array.
|
||||
"""
|
||||
resources = [USER_RESOURCE_TYPE, GROUP_RESOURCE_TYPE]
|
||||
return ScimJSONResponse(
|
||||
content={
|
||||
"schemas": [SCIM_LIST_RESPONSE_SCHEMA],
|
||||
"totalResults": len(resources),
|
||||
"Resources": [
|
||||
r.model_dump(exclude_none=True, by_alias=True) for r in resources
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@scim_router.get("/Schemas")
|
||||
def get_schemas() -> list[ScimSchemaDefinition]:
|
||||
"""Return SCIM schema definitions (RFC 7643 §7)."""
|
||||
return [USER_SCHEMA_DEF, GROUP_SCHEMA_DEF]
|
||||
def get_schemas() -> ScimJSONResponse:
|
||||
"""Return SCIM schema definitions (RFC 7643 §7).
|
||||
|
||||
Wrapped in a ListResponse envelope (RFC 7644 §3.4.2) because IdPs
|
||||
like Entra ID expect a JSON object, not a bare array.
|
||||
"""
|
||||
schemas = [USER_SCHEMA_DEF, GROUP_SCHEMA_DEF, ENTERPRISE_USER_SCHEMA_DEF]
|
||||
return ScimJSONResponse(
|
||||
content={
|
||||
"schemas": [SCIM_LIST_RESPONSE_SCHEMA],
|
||||
"totalResults": len(schemas),
|
||||
"Resources": [s.model_dump(exclude_none=True) for s in schemas],
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
 # ---------------------------------------------------------------------------
@@ -91,35 +160,43 @@ def get_schemas() -> list[ScimSchemaDefinition]:
 # ---------------------------------------------------------------------------


-def _scim_error_response(status: int, detail: str) -> JSONResponse:
+def _scim_error_response(status: int, detail: str) -> ScimJSONResponse:
     """Build a SCIM-compliant error response (RFC 7644 §3.12)."""
     logger.warning("SCIM error response: status=%s detail=%s", status, detail)
     body = ScimError(status=str(status), detail=detail)
-    return JSONResponse(
+    return ScimJSONResponse(
         status_code=status,
         content=body.model_dump(exclude_none=True),
     )


-def _user_to_scim(user: User, external_id: str | None = None) -> ScimUserResource:
-    """Convert an Onyx User to a SCIM User resource representation."""
-    name = None
-    if user.personal_name:
-        parts = user.personal_name.split(" ", 1)
-        name = ScimName(
-            givenName=parts[0],
-            familyName=parts[1] if len(parts) > 1 else None,
-            formatted=user.personal_name,
-        )
-
-    return ScimUserResource(
-        id=str(user.id),
-        externalId=external_id,
-        userName=user.email,
-        name=name,
-        emails=[ScimEmail(value=user.email, type="work", primary=True)],
-        active=user.is_active,
-        meta=ScimMeta(resourceType="User"),
-    )
+def _parse_excluded_attributes(raw: str | None) -> set[str]:
+    """Parse the ``excludedAttributes`` query parameter (RFC 7644 §3.4.2.5).
+
+    Returns a set of lowercased attribute names to omit from responses.
+    """
+    if not raw:
+        return set()
+    return {attr.strip().lower() for attr in raw.split(",") if attr.strip()}
+
+
+def _apply_exclusions(
+    resource: ScimUserResource | ScimGroupResource,
+    excluded: set[str],
+) -> dict:
+    """Serialize a SCIM resource, omitting attributes the IdP excluded.
+
+    RFC 7644 §3.4.2.5 lets the IdP pass ``?excludedAttributes=groups,emails``
+    to reduce response payload size. We strip those fields after serialization
+    so the rest of the pipeline doesn't need to know about them.
+    """
+    data = resource.model_dump(exclude_none=True, by_alias=True)
+    for attr in excluded:
+        # Match case-insensitively against the camelCase field names
+        keys_to_remove = [k for k in data if k.lower() == attr]
+        for k in keys_to_remove:
+            del data[k]
+    return data

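A quick illustration of the two helpers above, values invented:

    # excludedAttributes parsing and application, illustrative only:
    excluded = _parse_excluded_attributes("groups, Emails")
    # -> {"groups", "emails"}

    # Given a serialized user containing "emails" and "groups" keys,
    # _apply_exclusions drops both case-insensitively before the response
    # is sent, per RFC 7644 §3.4.2.5.
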
 def _check_seat_availability(dal: ScimDAL) -> str | None:
@@ -135,7 +212,7 @@ def _check_seat_availability(dal: ScimDAL) -> str | None:
     return None


-def _fetch_user_or_404(user_id: str, dal: ScimDAL) -> User | JSONResponse:
+def _fetch_user_or_404(user_id: str, dal: ScimDAL) -> User | ScimJSONResponse:
     """Parse *user_id* as UUID, look up the user, or return a 404 error."""
     try:
         uid = UUID(user_id)
@@ -155,8 +232,94 @@ def _scim_name_to_str(name: ScimName | None) -> str | None:
     """
     if not name:
         return None
-    return name.formatted or " ".join(
-        part for part in [name.givenName, name.familyName] if part
-    )
+    # If the client explicitly provides ``formatted``, prefer it — the client
+    # knows what display string it wants. Otherwise build from components.
+    if name.formatted:
+        return name.formatted
+    parts = " ".join(part for part in [name.givenName, name.familyName] if part)
+    return parts or None

+def _scim_resource_response(
+    resource: ScimUserResource | ScimGroupResource | ScimListResponse,
+    status_code: int = 200,
+) -> ScimJSONResponse:
+    """Serialize a SCIM resource as ``application/scim+json``."""
+    content = resource.model_dump(exclude_none=True, by_alias=True)
+    return ScimJSONResponse(
+        status_code=status_code,
+        content=content,
+    )
+
+
+def _build_list_response(
+    resources: list[ScimUserResource | ScimGroupResource],
+    total: int,
+    start_index: int,
+    count: int,
+    excluded: set[str] | None = None,
+) -> ScimListResponse | ScimJSONResponse:
+    """Build a SCIM list response, optionally applying attribute exclusions.
+
+    RFC 7644 §3.4.2.5 — IdPs may request certain attributes be omitted via
+    the ``excludedAttributes`` query parameter.
+    """
+    if excluded:
+        envelope = ScimListResponse(
+            totalResults=total,
+            startIndex=start_index,
+            itemsPerPage=count,
+        )
+        data = envelope.model_dump(exclude_none=True)
+        data["Resources"] = [_apply_exclusions(r, excluded) for r in resources]
+        return ScimJSONResponse(content=data)
+
+    return _scim_resource_response(
+        ScimListResponse(
+            totalResults=total,
+            startIndex=start_index,
+            itemsPerPage=count,
+            Resources=resources,
+        )
+    )

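For reference, the envelope produced by _build_list_response looks roughly like this on the wire (RFC 7644 §3.4.2; ids and counts invented):

    {
        "schemas": ["urn:ietf:params:scim:api:messages:2.0:ListResponse"],
        "totalResults": 1,
        "startIndex": 1,
        "itemsPerPage": 100,
        "Resources": [
            {
                "id": "8f0c8e4e-0000-0000-0000-000000000000",
                "userName": "jane@example.com",
                "active": true
            }
        ]
    }
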
+def _extract_enterprise_fields(
+    resource: ScimUserResource,
+) -> tuple[str | None, str | None]:
+    """Extract department and manager from enterprise extension."""
+    ext = resource.enterprise_extension
+    if not ext:
+        return None, None
+    department = ext.department
+    manager = ext.manager.value if ext.manager else None
+    return department, manager
+
+
+def _mapping_to_fields(
+    mapping: ScimUserMapping | None,
+) -> ScimMappingFields | None:
+    """Extract round-trip fields from a SCIM user mapping."""
+    if not mapping:
+        return None
+    return ScimMappingFields(
+        department=mapping.department,
+        manager=mapping.manager,
+        given_name=mapping.given_name,
+        family_name=mapping.family_name,
+        scim_emails_json=mapping.scim_emails_json,
+    )
+
+
+def _fields_from_resource(resource: ScimUserResource) -> ScimMappingFields:
+    """Build mapping fields from an incoming SCIM user resource."""
+    department, manager = _extract_enterprise_fields(resource)
+    return ScimMappingFields(
+        department=department,
+        manager=manager,
+        given_name=resource.name.givenName if resource.name else None,
+        family_name=resource.name.familyName if resource.name else None,
+        scim_emails_json=serialize_emails(resource.emails),
+    )

@@ -168,14 +331,17 @@ def _scim_name_to_str(name: ScimName | None) -> str | None:
 @scim_router.get("/Users", response_model=None)
 def list_users(
     filter: str | None = Query(None),
+    excludedAttributes: str | None = None,
     startIndex: int = Query(1, ge=1),
     count: int = Query(100, ge=0, le=500),
     _token: ScimToken = Depends(verify_scim_token),
+    provider: ScimProvider = Depends(_get_provider),
     db_session: Session = Depends(get_session),
-) -> ScimListResponse | JSONResponse:
+) -> ScimListResponse | ScimJSONResponse:
     """List users with optional SCIM filter and pagination."""
     dal = ScimDAL(db_session)
     dal.update_token_last_used(_token.id)
     dal.commit()

     try:
         scim_filter = parse_scim_filter(filter)
@@ -183,68 +349,137 @@ def list_users(
         return _scim_error_response(400, str(e))

     try:
-        users_with_ext_ids, total = dal.list_users(scim_filter, startIndex, count)
+        users_with_mappings, total = dal.list_users(scim_filter, startIndex, count)
     except ValueError as e:
         return _scim_error_response(400, str(e))

+    user_groups_map = dal.get_users_groups_batch([u.id for u, _ in users_with_mappings])
     resources: list[ScimUserResource | ScimGroupResource] = [
-        _user_to_scim(user, ext_id) for user, ext_id in users_with_ext_ids
+        provider.build_user_resource(
+            user,
+            mapping.external_id if mapping else None,
+            groups=user_groups_map.get(user.id, []),
+            scim_username=mapping.scim_username if mapping else None,
+            fields=_mapping_to_fields(mapping),
+        )
+        for user, mapping in users_with_mappings
     ]

-    return ScimListResponse(
-        totalResults=total,
-        startIndex=startIndex,
-        itemsPerPage=count,
-        Resources=resources,
+    return _build_list_response(
+        resources,
+        total,
+        startIndex,
+        count,
+        excluded=_parse_excluded_attributes(excludedAttributes),
     )

@scim_router.get("/Users/{user_id}", response_model=None)
|
||||
def get_user(
|
||||
user_id: str,
|
||||
excludedAttributes: str | None = None,
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimUserResource | JSONResponse:
|
||||
) -> ScimUserResource | ScimJSONResponse:
|
||||
"""Get a single user by ID."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
dal.commit()
|
||||
|
||||
result = _fetch_user_or_404(user_id, dal)
|
||||
if isinstance(result, JSONResponse):
|
||||
if isinstance(result, ScimJSONResponse):
|
||||
return result
|
||||
user = result
|
||||
|
||||
mapping = dal.get_user_mapping_by_user_id(user.id)
|
||||
return _user_to_scim(user, mapping.external_id if mapping else None)
|
||||
|
||||
resource = provider.build_user_resource(
|
||||
user,
|
||||
mapping.external_id if mapping else None,
|
||||
groups=dal.get_user_groups(user.id),
|
||||
scim_username=mapping.scim_username if mapping else None,
|
||||
fields=_mapping_to_fields(mapping),
|
||||
)
|
||||
|
||||
# RFC 7644 §3.4.2.5 — IdP may request certain attributes be omitted
|
||||
excluded = _parse_excluded_attributes(excludedAttributes)
|
||||
if excluded:
|
||||
return ScimJSONResponse(content=_apply_exclusions(resource, excluded))
|
||||
|
||||
return _scim_resource_response(resource)
|
||||
|
||||
|
||||
@scim_router.post("/Users", status_code=201, response_model=None)
|
||||
def create_user(
|
||||
user_resource: ScimUserResource,
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimUserResource | JSONResponse:
|
||||
) -> ScimUserResource | ScimJSONResponse:
|
||||
"""Create a new user from a SCIM provisioning request."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
email = user_resource.userName.strip().lower()
|
||||
email = user_resource.userName.strip()
|
||||
|
||||
# externalId is how the IdP correlates this user on subsequent requests.
|
||||
# Without it, the IdP can't find the user and will try to re-create,
|
||||
# hitting a 409 conflict — so we require it up front.
|
||||
if not user_resource.externalId:
|
||||
return _scim_error_response(400, "externalId is required")
|
||||
# Check for existing user — if they exist but aren't SCIM-managed yet,
|
||||
# link them to the IdP rather than rejecting with 409.
|
||||
external_id: str | None = user_resource.externalId
|
||||
scim_username: str = user_resource.userName.strip()
|
||||
fields: ScimMappingFields = _fields_from_resource(user_resource)
|
||||
|
||||
# Enforce seat limit
|
||||
existing_user = dal.get_user_by_email(email)
|
||||
if existing_user:
|
||||
existing_mapping = dal.get_user_mapping_by_user_id(existing_user.id)
|
||||
if existing_mapping:
|
||||
return _scim_error_response(409, f"User with email {email} already exists")
|
||||
|
||||
# Adopt pre-existing user into SCIM management.
|
||||
# Reactivating a deactivated user consumes a seat, so enforce the
|
||||
# seat limit the same way replace_user does.
|
||||
if user_resource.active and not existing_user.is_active:
|
||||
seat_error = _check_seat_availability(dal)
|
||||
if seat_error:
|
||||
return _scim_error_response(403, seat_error)
|
||||
|
||||
personal_name = _scim_name_to_str(user_resource.name)
|
||||
dal.update_user(
|
||||
existing_user,
|
||||
is_active=user_resource.active,
|
||||
**({"personal_name": personal_name} if personal_name else {}),
|
||||
)
|
||||
|
||||
try:
|
||||
dal.create_user_mapping(
|
||||
external_id=external_id,
|
||||
user_id=existing_user.id,
|
||||
scim_username=scim_username,
|
||||
fields=fields,
|
||||
)
|
||||
dal.commit()
|
||||
except IntegrityError:
|
||||
dal.rollback()
|
||||
return _scim_error_response(
|
||||
409, f"User with email {email} already has a SCIM mapping"
|
||||
)
|
||||
|
||||
return _scim_resource_response(
|
||||
provider.build_user_resource(
|
||||
existing_user,
|
||||
external_id,
|
||||
scim_username=scim_username,
|
||||
fields=fields,
|
||||
),
|
||||
status_code=201,
|
||||
)
|
||||
|
||||
# Only enforce seat limit for net-new users — adopting a pre-existing
|
||||
# user doesn't consume a new seat.
|
||||
seat_error = _check_seat_availability(dal)
|
||||
if seat_error:
|
||||
return _scim_error_response(403, seat_error)
|
||||
|
||||
# Check for existing user
|
||||
if dal.get_user_by_email(email):
|
||||
return _scim_error_response(409, f"User with email {email} already exists")
|
||||
|
||||
# Create user with a random password (SCIM users authenticate via IdP)
|
||||
personal_name = _scim_name_to_str(user_resource.name)
|
||||
user = User(
|
||||
@@ -262,13 +497,31 @@ def create_user(
|
||||
dal.rollback()
|
||||
return _scim_error_response(409, f"User with email {email} already exists")
|
||||
|
||||
# Create SCIM mapping (externalId is validated above, always present)
|
||||
external_id = user_resource.externalId
|
||||
dal.create_user_mapping(external_id=external_id, user_id=user.id)
|
||||
# Always create a SCIM mapping so that the user is marked as
|
||||
# SCIM-managed. externalId may be None (RFC 7643 says it's optional).
|
||||
try:
|
||||
dal.create_user_mapping(
|
||||
external_id=external_id,
|
||||
user_id=user.id,
|
||||
scim_username=scim_username,
|
||||
fields=fields,
|
||||
)
|
||||
dal.commit()
|
||||
except IntegrityError:
|
||||
dal.rollback()
|
||||
return _scim_error_response(
|
||||
409, f"User with email {email} already has a SCIM mapping"
|
||||
)
|
||||
|
||||
dal.commit()
|
||||
|
||||
return _user_to_scim(user, external_id)
|
||||
return _scim_resource_response(
|
||||
provider.build_user_resource(
|
||||
user,
|
||||
external_id,
|
||||
scim_username=scim_username,
|
||||
fields=fields,
|
||||
),
|
||||
status_code=201,
|
||||
)
|
||||
|
||||
|
||||
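For orientation, a typical IdP provisioning request handled by create_user (shape per RFC 7643; header and field values invented):

    POST /scim/v2/Users
    Content-Type: application/scim+json

    {
        "schemas": ["urn:ietf:params:scim:core:2.0:User"],
        "externalId": "00u1abcd2EFGHIJ3x4y5",
        "userName": "jane@example.com",
        "name": {"givenName": "Jane", "familyName": "Doe"},
        "emails": [{"value": "jane@example.com", "type": "work", "primary": true}],
        "active": true
    }
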
@scim_router.put("/Users/{user_id}", response_model=None)
|
||||
@@ -276,14 +529,15 @@ def replace_user(
|
||||
user_id: str,
|
||||
user_resource: ScimUserResource,
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimUserResource | JSONResponse:
|
||||
) -> ScimUserResource | ScimJSONResponse:
|
||||
"""Replace a user entirely (RFC 7644 §3.5.1)."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
result = _fetch_user_or_404(user_id, dal)
|
||||
if isinstance(result, JSONResponse):
|
||||
if isinstance(result, ScimJSONResponse):
|
||||
return result
|
||||
user = result
|
||||
|
||||
@@ -293,19 +547,36 @@ def replace_user(
|
||||
if seat_error:
|
||||
return _scim_error_response(403, seat_error)
|
||||
|
||||
personal_name = _scim_name_to_str(user_resource.name)
|
||||
|
||||
dal.update_user(
|
||||
user,
|
||||
email=user_resource.userName.strip().lower(),
|
||||
email=user_resource.userName.strip(),
|
||||
is_active=user_resource.active,
|
||||
personal_name=_scim_name_to_str(user_resource.name),
|
||||
personal_name=personal_name,
|
||||
)
|
||||
|
||||
new_external_id = user_resource.externalId
|
||||
dal.sync_user_external_id(user.id, new_external_id)
|
||||
scim_username = user_resource.userName.strip()
|
||||
fields = _fields_from_resource(user_resource)
|
||||
dal.sync_user_external_id(
|
||||
user.id,
|
||||
new_external_id,
|
||||
scim_username=scim_username,
|
||||
fields=fields,
|
||||
)
|
||||
|
||||
dal.commit()
|
||||
|
||||
return _user_to_scim(user, new_external_id)
|
||||
return _scim_resource_response(
|
||||
provider.build_user_resource(
|
||||
user,
|
||||
new_external_id,
|
||||
groups=dal.get_user_groups(user.id),
|
||||
scim_username=scim_username,
|
||||
fields=fields,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@scim_router.patch("/Users/{user_id}", response_model=None)
|
||||
@@ -313,8 +584,9 @@ def patch_user(
|
||||
user_id: str,
|
||||
patch_request: ScimPatchRequest,
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimUserResource | JSONResponse:
|
||||
) -> ScimUserResource | ScimJSONResponse:
|
||||
"""Partially update a user (RFC 7644 §3.5.2).
|
||||
|
||||
This is the primary endpoint for user deprovisioning — Okta sends
|
||||
@@ -324,17 +596,27 @@ def patch_user(
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
result = _fetch_user_or_404(user_id, dal)
|
||||
if isinstance(result, JSONResponse):
|
||||
if isinstance(result, ScimJSONResponse):
|
||||
return result
|
||||
user = result
|
||||
|
||||
mapping = dal.get_user_mapping_by_user_id(user.id)
|
||||
external_id = mapping.external_id if mapping else None
|
||||
current_scim_username = mapping.scim_username if mapping else None
|
||||
current_fields = _mapping_to_fields(mapping)
|
||||
|
||||
current = _user_to_scim(user, external_id)
|
||||
current = provider.build_user_resource(
|
||||
user,
|
||||
external_id,
|
||||
groups=dal.get_user_groups(user.id),
|
||||
scim_username=current_scim_username,
|
||||
fields=current_fields,
|
||||
)
|
||||
|
||||
try:
|
||||
patched = apply_user_patch(patch_request.Operations, current)
|
||||
patched, ent_data = apply_user_patch(
|
||||
patch_request.Operations, current, provider.ignored_patch_paths
|
||||
)
|
||||
except ScimPatchError as e:
|
||||
return _scim_error_response(e.status, e.detail)
|
||||
|
||||
@@ -345,22 +627,60 @@ def patch_user(
|
||||
if seat_error:
|
||||
return _scim_error_response(403, seat_error)
|
||||
|
||||
# Track the scim_username — if userName was patched, update it
|
||||
new_scim_username = patched.userName.strip() if patched.userName else None
|
||||
|
||||
# If displayName was explicitly patched (different from the original), use
|
||||
# it as personal_name directly. Otherwise, derive from name components.
|
||||
personal_name: str | None
|
||||
if patched.displayName and patched.displayName != current.displayName:
|
||||
personal_name = patched.displayName
|
||||
else:
|
||||
personal_name = _scim_name_to_str(patched.name)
|
||||
|
||||
dal.update_user(
|
||||
user,
|
||||
email=(
|
||||
patched.userName.strip().lower()
|
||||
if patched.userName.lower() != user.email
|
||||
patched.userName.strip()
|
||||
if patched.userName.strip().lower() != user.email.lower()
|
||||
else None
|
||||
),
|
||||
is_active=patched.active if patched.active != user.is_active else None,
|
||||
personal_name=_scim_name_to_str(patched.name),
|
||||
personal_name=personal_name,
|
||||
)
|
||||
|
||||
dal.sync_user_external_id(user.id, patched.externalId)
|
||||
# Build updated fields by merging PATCH enterprise data with current values
|
||||
cf = current_fields or ScimMappingFields()
|
||||
fields = ScimMappingFields(
|
||||
department=ent_data.get("department", cf.department),
|
||||
manager=ent_data.get("manager", cf.manager),
|
||||
given_name=patched.name.givenName if patched.name else cf.given_name,
|
||||
family_name=patched.name.familyName if patched.name else cf.family_name,
|
||||
scim_emails_json=(
|
||||
serialize_emails(patched.emails)
|
||||
if patched.emails is not None
|
||||
else cf.scim_emails_json
|
||||
),
|
||||
)
|
||||
|
||||
dal.sync_user_external_id(
|
||||
user.id,
|
||||
patched.externalId,
|
||||
scim_username=new_scim_username,
|
||||
fields=fields,
|
||||
)
|
||||
|
||||
dal.commit()
|
||||
|
||||
return _user_to_scim(user, patched.externalId)
|
||||
return _scim_resource_response(
|
||||
provider.build_user_resource(
|
||||
user,
|
||||
patched.externalId,
|
||||
groups=dal.get_user_groups(user.id),
|
||||
scim_username=new_scim_username,
|
||||
fields=fields,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@scim_router.delete("/Users/{user_id}", status_code=204, response_model=None)
@@ -368,25 +688,29 @@ def delete_user(
    user_id: str,
    _token: ScimToken = Depends(verify_scim_token),
    db_session: Session = Depends(get_session),
) -> Response | JSONResponse:
) -> Response | ScimJSONResponse:
    """Delete a user (RFC 7644 §3.6).

    Deactivates the user and removes the SCIM mapping. Note that Okta
    typically uses PATCH active=false instead of DELETE.
    A second DELETE returns 404 per RFC 7644 §3.6.
    """
    dal = ScimDAL(db_session)
    dal.update_token_last_used(_token.id)

    result = _fetch_user_or_404(user_id, dal)
    if isinstance(result, JSONResponse):
    if isinstance(result, ScimJSONResponse):
        return result
    user = result

    dal.deactivate_user(user)

    # If no SCIM mapping exists, the user was already deleted from
    # SCIM's perspective — return 404 per RFC 7644 §3.6.
    mapping = dal.get_user_mapping_by_user_id(user.id)
    if mapping:
        dal.delete_user_mapping(mapping.id)
    if not mapping:
        return _scim_error_response(404, f"User {user_id} not found")

    dal.deactivate_user(user)
    dal.delete_user_mapping(mapping.id)

    dal.commit()

@@ -398,25 +722,7 @@ def delete_user(
# ---------------------------------------------------------------------------


def _group_to_scim(
    group: UserGroup,
    members: list[tuple[UUID, str | None]],
    external_id: str | None = None,
) -> ScimGroupResource:
    """Convert an Onyx UserGroup to a SCIM Group resource."""
    scim_members = [
        ScimGroupMember(value=str(uid), display=email) for uid, email in members
    ]
    return ScimGroupResource(
        id=str(group.id),
        externalId=external_id,
        displayName=group.name,
        members=scim_members,
        meta=ScimMeta(resourceType="Group"),
    )


def _fetch_group_or_404(group_id: str, dal: ScimDAL) -> UserGroup | JSONResponse:
def _fetch_group_or_404(group_id: str, dal: ScimDAL) -> UserGroup | ScimJSONResponse:
    """Parse *group_id* as int, look up the group, or return a 404 error."""
    try:
        gid = int(group_id)
@@ -471,14 +777,17 @@ def _validate_and_parse_members(
@scim_router.get("/Groups", response_model=None)
def list_groups(
    filter: str | None = Query(None),
    excludedAttributes: str | None = None,
    startIndex: int = Query(1, ge=1),
    count: int = Query(100, ge=0, le=500),
    _token: ScimToken = Depends(verify_scim_token),
    provider: ScimProvider = Depends(_get_provider),
    db_session: Session = Depends(get_session),
) -> ScimListResponse | JSONResponse:
) -> ScimListResponse | ScimJSONResponse:
    """List groups with optional SCIM filter and pagination."""
    dal = ScimDAL(db_session)
    dal.update_token_last_used(_token.id)
    dal.commit()

    try:
        scim_filter = parse_scim_filter(filter)
@@ -491,45 +800,59 @@ def list_groups(
        return _scim_error_response(400, str(e))

    resources: list[ScimUserResource | ScimGroupResource] = [
        _group_to_scim(group, dal.get_group_members(group.id), ext_id)
        provider.build_group_resource(group, dal.get_group_members(group.id), ext_id)
        for group, ext_id in groups_with_ext_ids
    ]

    return ScimListResponse(
        totalResults=total,
        startIndex=startIndex,
        itemsPerPage=count,
        Resources=resources,
    return _build_list_response(
        resources,
        total,
        startIndex,
        count,
        excluded=_parse_excluded_attributes(excludedAttributes),
    )


@scim_router.get("/Groups/{group_id}", response_model=None)
def get_group(
    group_id: str,
    excludedAttributes: str | None = None,
    _token: ScimToken = Depends(verify_scim_token),
    provider: ScimProvider = Depends(_get_provider),
    db_session: Session = Depends(get_session),
) -> ScimGroupResource | JSONResponse:
) -> ScimGroupResource | ScimJSONResponse:
    """Get a single group by ID."""
    dal = ScimDAL(db_session)
    dal.update_token_last_used(_token.id)
    dal.commit()

    result = _fetch_group_or_404(group_id, dal)
    if isinstance(result, JSONResponse):
    if isinstance(result, ScimJSONResponse):
        return result
    group = result

    mapping = dal.get_group_mapping_by_group_id(group.id)
    members = dal.get_group_members(group.id)

    return _group_to_scim(group, members, mapping.external_id if mapping else None)
    resource = provider.build_group_resource(
        group, members, mapping.external_id if mapping else None
    )

    # RFC 7644 §3.4.2.5 — IdP may request certain attributes be omitted
    excluded = _parse_excluded_attributes(excludedAttributes)
    if excluded:
        return ScimJSONResponse(content=_apply_exclusions(resource, excluded))

    return _scim_resource_response(resource)

@scim_router.post("/Groups", status_code=201, response_model=None)
def create_group(
    group_resource: ScimGroupResource,
    _token: ScimToken = Depends(verify_scim_token),
    provider: ScimProvider = Depends(_get_provider),
    db_session: Session = Depends(get_session),
) -> ScimGroupResource | JSONResponse:
) -> ScimGroupResource | ScimJSONResponse:
    """Create a new group from a SCIM provisioning request."""
    dal = ScimDAL(db_session)
    dal.update_token_last_used(_token.id)
@@ -565,7 +888,10 @@ def create_group(
    dal.commit()

    members = dal.get_group_members(db_group.id)
    return _group_to_scim(db_group, members, external_id)
    return _scim_resource_response(
        provider.build_group_resource(db_group, members, external_id),
        status_code=201,
    )


@scim_router.put("/Groups/{group_id}", response_model=None)
@@ -573,14 +899,15 @@ def replace_group(
    group_id: str,
    group_resource: ScimGroupResource,
    _token: ScimToken = Depends(verify_scim_token),
    provider: ScimProvider = Depends(_get_provider),
    db_session: Session = Depends(get_session),
) -> ScimGroupResource | JSONResponse:
) -> ScimGroupResource | ScimJSONResponse:
    """Replace a group entirely (RFC 7644 §3.5.1)."""
    dal = ScimDAL(db_session)
    dal.update_token_last_used(_token.id)

    result = _fetch_group_or_404(group_id, dal)
    if isinstance(result, JSONResponse):
    if isinstance(result, ScimJSONResponse):
        return result
    group = result

@@ -595,7 +922,9 @@ def replace_group(
    dal.commit()

    members = dal.get_group_members(group.id)
    return _group_to_scim(group, members, group_resource.externalId)
    return _scim_resource_response(
        provider.build_group_resource(group, members, group_resource.externalId)
    )


@scim_router.patch("/Groups/{group_id}", response_model=None)
@@ -603,8 +932,9 @@ def patch_group(
    group_id: str,
    patch_request: ScimPatchRequest,
    _token: ScimToken = Depends(verify_scim_token),
    provider: ScimProvider = Depends(_get_provider),
    db_session: Session = Depends(get_session),
) -> ScimGroupResource | JSONResponse:
) -> ScimGroupResource | ScimJSONResponse:
    """Partially update a group (RFC 7644 §3.5.2).

    Handles member add/remove operations from Okta and Azure AD.
@@ -613,7 +943,7 @@ def patch_group(
    dal.update_token_last_used(_token.id)

    result = _fetch_group_or_404(group_id, dal)
    if isinstance(result, JSONResponse):
    if isinstance(result, ScimJSONResponse):
        return result
    group = result

@@ -621,11 +951,11 @@ def patch_group(
    external_id = mapping.external_id if mapping else None

    current_members = dal.get_group_members(group.id)
    current = _group_to_scim(group, current_members, external_id)
    current = provider.build_group_resource(group, current_members, external_id)

    try:
        patched, added_ids, removed_ids = apply_group_patch(
            patch_request.Operations, current
            patch_request.Operations, current, provider.ignored_patch_paths
        )
    except ScimPatchError as e:
        return _scim_error_response(e.status, e.detail)

@@ -652,7 +982,9 @@ def patch_group(
    dal.commit()

    members = dal.get_group_members(group.id)
    return _group_to_scim(group, members, patched.externalId)
    return _scim_resource_response(
        provider.build_group_resource(group, members, patched.externalId)
    )


@scim_router.delete("/Groups/{group_id}", status_code=204, response_model=None)
@@ -660,13 +992,13 @@ def delete_group(
    group_id: str,
    _token: ScimToken = Depends(verify_scim_token),
    db_session: Session = Depends(get_session),
) -> Response | JSONResponse:
) -> Response | ScimJSONResponse:
    """Delete a group (RFC 7644 §3.6)."""
    dal = ScimDAL(db_session)
    dal.update_token_last_used(_token.id)

    result = _fetch_group_or_404(group_id, dal)
    if isinstance(result, JSONResponse):
    if isinstance(result, ScimJSONResponse):
        return result
    group = result

@@ -19,7 +19,6 @@ import hashlib
import secrets

from fastapi import Depends
from fastapi import HTTPException
from fastapi import Request
from sqlalchemy.orm import Session

@@ -28,6 +27,21 @@ from onyx.auth.utils import get_hashed_bearer_token_from_request
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import ScimToken


class ScimAuthError(Exception):
    """Raised when SCIM bearer token authentication fails.

    Unlike HTTPException, this carries the status and detail so the SCIM
    exception handler can wrap them in an RFC 7644 §3.12 error envelope
    with ``schemas`` and ``status`` fields.
    """

    def __init__(self, status_code: int, detail: str) -> None:
        self.status_code = status_code
        self.detail = detail
        super().__init__(detail)
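
A minimal sketch of the exception handler the docstring implies; the handler name and registration below are assumptions for illustration, not part of this diff:

# Hypothetical wiring: wraps ScimAuthError in the RFC 7644 §3.12 envelope.
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse

def scim_auth_error_handler(request: Request, exc: ScimAuthError) -> JSONResponse:
    # "schemas" and "status" are the fields RFC 7644 requires on SCIM errors.
    return JSONResponse(
        status_code=exc.status_code,
        content={
            "schemas": ["urn:ietf:params:scim:api:messages:2.0:Error"],
            "status": str(exc.status_code),
            "detail": exc.detail,
        },
    )

app = FastAPI()
app.add_exception_handler(ScimAuthError, scim_auth_error_handler)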

SCIM_TOKEN_PREFIX = "onyx_scim_"
SCIM_TOKEN_LENGTH = 48

@@ -82,23 +96,14 @@ def verify_scim_token(
    """
    hashed = _get_hashed_scim_token_from_request(request)
    if not hashed:
        raise HTTPException(
            status_code=401,
            detail="Missing or invalid SCIM bearer token",
        )
        raise ScimAuthError(401, "Missing or invalid SCIM bearer token")

    token = dal.get_token_by_hash(hashed)

    if not token:
        raise HTTPException(
            status_code=401,
            detail="Invalid SCIM bearer token",
        )
        raise ScimAuthError(401, "Invalid SCIM bearer token")

    if not token.is_active:
        raise HTTPException(
            status_code=401,
            detail="SCIM token has been revoked",
        )
        raise ScimAuthError(401, "SCIM token has been revoked")

    return token

@@ -7,12 +7,14 @@ SCIM protocol schemas follow the wire format defined in:
Admin API schemas are internal to Onyx and used for SCIM token management.
"""

from dataclasses import dataclass
from datetime import datetime
from enum import Enum

from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import Field
from pydantic import field_validator


# ---------------------------------------------------------------------------
@@ -31,6 +33,9 @@ SCIM_SERVICE_PROVIDER_CONFIG_SCHEMA = (
)
SCIM_RESOURCE_TYPE_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:ResourceType"
SCIM_SCHEMA_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:Schema"
SCIM_ENTERPRISE_USER_SCHEMA = (
    "urn:ietf:params:scim:schemas:extension:enterprise:2.0:User"
)


# ---------------------------------------------------------------------------
@@ -63,6 +68,43 @@ class ScimMeta(BaseModel):
    location: str | None = None


class ScimUserGroupRef(BaseModel):
    """Group reference within a User resource (RFC 7643 §4.1.2, read-only)."""

    value: str
    display: str | None = None


class ScimManagerRef(BaseModel):
    """Manager sub-attribute for the enterprise extension (RFC 7643 §4.3)."""

    value: str | None = None


class ScimEnterpriseExtension(BaseModel):
    """Enterprise User extension attributes (RFC 7643 §4.3)."""

    department: str | None = None
    manager: ScimManagerRef | None = None


@dataclass
class ScimMappingFields:
    """Stored SCIM mapping fields that need to round-trip through the IdP.

    Entra ID sends structured name components, email metadata, and enterprise
    extension attributes that must be returned verbatim in subsequent GET
    responses. These fields are persisted on ScimUserMapping and threaded
    through the DAL, provider, and endpoint layers.
    """

    department: str | None = None
    manager: str | None = None
    given_name: str | None = None
    family_name: str | None = None
    scim_emails_json: str | None = None
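
As a quick sketch, this is the shape of data the mapping stores for an Entra-style user (all values invented for illustration):

# Persisted so later GETs can echo back exactly what the IdP sent.
fields = ScimMappingFields(
    department="Engineering",
    manager="2819c223-7f76-453a-919d-413861904646",
    given_name="Ada",
    family_name="Lovelace",
    scim_emails_json='[{"value": "ada@example.com", "type": "work", "primary": true}]',
)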


class ScimUserResource(BaseModel):
    """SCIM User resource representation (RFC 7643 §4.1).

@@ -71,14 +113,22 @@ class ScimUserResource(BaseModel):
    to match the SCIM wire format (not Python convention).
    """

    model_config = ConfigDict(populate_by_name=True)

    schemas: list[str] = Field(default_factory=lambda: [SCIM_USER_SCHEMA])
    id: str | None = None  # Onyx's internal user ID, set on responses
    externalId: str | None = None  # IdP's identifier for this user
    userName: str  # Typically the user's email address
    name: ScimName | None = None
    displayName: str | None = None
    emails: list[ScimEmail] = Field(default_factory=list)
    active: bool = True
    groups: list[ScimUserGroupRef] = Field(default_factory=list)
    meta: ScimMeta | None = None
    enterprise_extension: ScimEnterpriseExtension | None = Field(
        default=None,
        alias="urn:ietf:params:scim:schemas:extension:enterprise:2.0:User",
    )


class ScimGroupMember(BaseModel):
@@ -121,12 +171,53 @@ class ScimPatchOperationType(str, Enum):
    REMOVE = "remove"


class ScimPatchResourceValue(BaseModel):
    """Partial resource dict for path-less PATCH replace operations.

    When an IdP sends a PATCH without a ``path``, the ``value`` is a dict
    of resource attributes to set. IdPs may include read-only fields
    (``id``, ``schemas``, ``meta``) alongside actual changes — these are
    stripped by the provider's ``ignored_patch_paths`` before processing.

    ``extra="allow"`` lets unknown attributes pass through so the patch
    handler can decide what to do with them (ignore or reject).
    """

    model_config = ConfigDict(extra="allow")

    active: bool | None = None
    userName: str | None = None
    displayName: str | None = None
    externalId: str | None = None
    name: ScimName | None = None
    members: list[ScimGroupMember] | None = None
    id: str | None = None
    schemas: list[str] | None = None
    meta: ScimMeta | None = None


ScimPatchValue = str | bool | list[ScimGroupMember] | ScimPatchResourceValue | None
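
A small check of the ``extra="allow"`` behavior described above (payload invented):

# Unknown attributes survive validation and land in model_extra, so the
# patch handler can decide whether to ignore or reject them.
v = ScimPatchResourceValue.model_validate(
    {"id": "abc", "active": False, "someVendorAttr": "x"}
)
assert v.active is False
assert v.model_extra == {"someVendorAttr": "x"}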


class ScimPatchOperation(BaseModel):
    """Single PATCH operation (RFC 7644 §3.5.2)."""

    op: ScimPatchOperationType
    path: str | None = None
    value: str | list[dict[str, str]] | dict[str, str | bool] | bool | None = None
    value: ScimPatchValue = None

    @field_validator("op", mode="before")
    @classmethod
    def normalize_operation(cls, v: object) -> object:
        """Normalize op to lowercase for case-insensitive matching.

        Some IdPs (e.g. Entra ID) send capitalized ops like ``"Replace"``
        instead of ``"replace"``. This is safe for all providers since the
        enum values are lowercase. If a future provider requires other
        pre-processing quirks, move patch deserialization into the provider
        subclass instead of adding more special cases here.
        """
        return v.lower() if isinstance(v, str) else v
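
A one-line sanity check of the normalization (payload invented):

# An Entra-style capitalized op parses to the lowercase enum member.
op = ScimPatchOperation.model_validate({"op": "Replace", "path": "active", "value": False})
assert op.op is ScimPatchOperationType.REPLACE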


class ScimPatchRequest(BaseModel):
@@ -274,6 +365,7 @@ class ScimTokenResponse(BaseModel):
    is_active: bool
    created_at: datetime
    last_used_at: datetime | None = None
    idp_domain: str | None = None


class ScimTokenCreatedResponse(ScimTokenResponse):

@@ -14,13 +14,70 @@ responsible for persisting changes.

from __future__ import annotations

import logging
import re
from dataclasses import dataclass
from dataclasses import field
from typing import Any

from ee.onyx.server.scim.models import SCIM_ENTERPRISE_USER_SCHEMA
from ee.onyx.server.scim.models import ScimGroupMember
from ee.onyx.server.scim.models import ScimGroupResource
from ee.onyx.server.scim.models import ScimPatchOperation
from ee.onyx.server.scim.models import ScimPatchOperationType
from ee.onyx.server.scim.models import ScimPatchResourceValue
from ee.onyx.server.scim.models import ScimPatchValue
from ee.onyx.server.scim.models import ScimUserResource

logger = logging.getLogger(__name__)

# Lowercased enterprise extension URN for case-insensitive matching
_ENTERPRISE_URN_LOWER = SCIM_ENTERPRISE_USER_SCHEMA.lower()

# Pattern for email filter paths, e.g.:
#   emails[primary eq true].value   (Okta)
#   emails[type eq "work"].value    (Azure AD / Entra ID)
_EMAIL_FILTER_RE = re.compile(
    r"^emails\[.+\]\.value$",
    re.IGNORECASE,
)

# Pattern for member removal path: members[value eq "user-id"]
_MEMBER_FILTER_RE = re.compile(
    r'^members\[value\s+eq\s+"([^"]+)"\]$',
    re.IGNORECASE,
)
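
Quick demonstrations of what these patterns accept, using the paths from the comments above (the member id is illustrative):

assert _EMAIL_FILTER_RE.match("emails[primary eq true].value")
assert _EMAIL_FILTER_RE.match('emails[type eq "work"].value')
m = _MEMBER_FILTER_RE.match('members[value eq "0b5a8e7c"]')
assert m is not None and m.group(1) == "0b5a8e7c"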

# ---------------------------------------------------------------------------
# Dispatch tables for user PATCH paths
#
# Maps lowercased SCIM path → (camelCase key, target dict name).
# "data" writes to the top-level resource dict, "name" writes to the
# name sub-object dict. This replaces the elif chains for simple fields.
# ---------------------------------------------------------------------------

_USER_REPLACE_PATHS: dict[str, tuple[str, str]] = {
    "active": ("active", "data"),
    "username": ("userName", "data"),
    "externalid": ("externalId", "data"),
    "name.givenname": ("givenName", "name"),
    "name.familyname": ("familyName", "name"),
    "name.formatted": ("formatted", "name"),
}

_USER_REMOVE_PATHS: dict[str, tuple[str, str]] = {
    "externalid": ("externalId", "data"),
    "name.givenname": ("givenName", "name"),
    "name.familyname": ("familyName", "name"),
    "name.formatted": ("formatted", "name"),
    "displayname": ("displayName", "data"),
}

_GROUP_REPLACE_PATHS: dict[str, tuple[str, str]] = {
    "displayname": ("displayName", "data"),
    "externalid": ("externalId", "data"),
}
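
In isolation, the lookup that replaces those elif chains works like this (the state dicts are placeholders):

# Each entry maps a lowercased SCIM path to its camelCase key plus the
# dict it writes into ("data" = top-level resource, "name" = sub-object).
data: dict = {}
name_data: dict = {}
key, target = _USER_REPLACE_PATHS["name.givenname"]
(data if target == "data" else name_data)[key] = "Ada"
assert name_data == {"givenName": "Ada"}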


class ScimPatchError(Exception):
    """Raised when a PATCH operation cannot be applied."""
@@ -31,94 +88,223 @@ class ScimPatchError(Exception):
        super().__init__(detail)


# Pattern for member removal path: members[value eq "user-id"]
_MEMBER_FILTER_RE = re.compile(
    r'^members\[value\s+eq\s+"([^"]+)"\]$',
    re.IGNORECASE,
)
@dataclass
class _UserPatchCtx:
    """Bundles the mutable state for user PATCH operations."""

    data: dict[str, Any]
    name_data: dict[str, Any]
    ent_data: dict[str, str | None] = field(default_factory=dict)


# ---------------------------------------------------------------------------
# User PATCH
# ---------------------------------------------------------------------------


def apply_user_patch(
    operations: list[ScimPatchOperation],
    current: ScimUserResource,
) -> ScimUserResource:
    ignored_paths: frozenset[str] = frozenset(),
) -> tuple[ScimUserResource, dict[str, str | None]]:
    """Apply SCIM PATCH operations to a user resource.

    Returns a new ``ScimUserResource`` with the modifications applied.
    The original object is not mutated.
    Args:
        operations: The PATCH operations to apply.
        current: The current user resource state.
        ignored_paths: SCIM attribute paths to silently skip (from provider).

    Returns:
        A tuple of (modified user resource, enterprise extension data dict).
        The enterprise dict has keys ``"department"`` and ``"manager"``
        with values set only when a PATCH operation touched them.

    Raises:
        ScimPatchError: If an operation targets an unsupported path.
    """
    data = current.model_dump()
    name_data = data.get("name") or {}
    ctx = _UserPatchCtx(data=data, name_data=data.get("name") or {})

    for op in operations:
        if op.op == ScimPatchOperationType.REPLACE:
            _apply_user_replace(op, data, name_data)
        elif op.op == ScimPatchOperationType.ADD:
            _apply_user_replace(op, data, name_data)
        if op.op in (ScimPatchOperationType.REPLACE, ScimPatchOperationType.ADD):
            _apply_user_replace(op, ctx, ignored_paths)
        elif op.op == ScimPatchOperationType.REMOVE:
            _apply_user_remove(op, ctx, ignored_paths)
        else:
            raise ScimPatchError(
                f"Unsupported operation '{op.op.value}' on User resource"
            )

    data["name"] = name_data
    return ScimUserResource.model_validate(data)
    ctx.data["name"] = ctx.name_data
    return ScimUserResource.model_validate(ctx.data), ctx.ent_data
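
As a usage sketch, an Okta-style deactivation PATCH flows through like this (the resource is fabricated for the example):

# One replace op on "active"; ent_data stays empty because no enterprise
# extension attribute was touched.
current = ScimUserResource(userName="ada@example.com", active=True)
ops = [
    ScimPatchOperation(op=ScimPatchOperationType.REPLACE, path="active", value=False)
]
patched, ent_data = apply_user_patch(ops, current)
assert patched.active is False and ent_data == {}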


def _apply_user_replace(
    op: ScimPatchOperation,
    data: dict,
    name_data: dict,
    ctx: _UserPatchCtx,
    ignored_paths: frozenset[str],
) -> None:
    """Apply a replace/add operation to user data."""
    path = (op.path or "").lower()

    if not path:
        # No path — value is a dict of top-level attributes to set
        if isinstance(op.value, dict):
            for key, val in op.value.items():
                _set_user_field(key.lower(), val, data, name_data)
        # No path — value is a resource dict of top-level attributes to set.
        if isinstance(op.value, ScimPatchResourceValue):
            for key, val in op.value.model_dump(exclude_unset=True).items():
                _set_user_field(key.lower(), val, ctx, ignored_paths, strict=False)
        else:
            raise ScimPatchError("Replace without path requires a dict value")
        return

    _set_user_field(path, op.value, data, name_data)
    _set_user_field(path, op.value, ctx, ignored_paths)


def _apply_user_remove(
    op: ScimPatchOperation,
    ctx: _UserPatchCtx,
    ignored_paths: frozenset[str],
) -> None:
    """Apply a remove operation to user data — clears the target field."""
    path = (op.path or "").lower()
    if not path:
        raise ScimPatchError("Remove operation requires a path")

    if path in ignored_paths:
        return

    entry = _USER_REMOVE_PATHS.get(path)
    if entry:
        key, target = entry
        target_dict = ctx.data if target == "data" else ctx.name_data
        target_dict[key] = None
        return

    raise ScimPatchError(f"Unsupported remove path '{path}' for User PATCH")


def _set_user_field(
    path: str,
    value: str | bool | dict | list | None,
    data: dict,
    name_data: dict,
    value: ScimPatchValue,
    ctx: _UserPatchCtx,
    ignored_paths: frozenset[str],
    *,
    strict: bool = True,
) -> None:
    """Set a single field on user data by SCIM path."""
    if path == "active":
        data["active"] = value
    elif path == "username":
        data["userName"] = value
    elif path == "externalid":
        data["externalId"] = value
    elif path == "name.givenname":
        name_data["givenName"] = value
    elif path == "name.familyname":
        name_data["familyName"] = value
    elif path == "name.formatted":
        name_data["formatted"] = value
    elif path == "displayname":
        # Some IdPs send displayName on users; map to formatted name
        name_data["formatted"] = value
    """Set a single field on user data by SCIM path.

    Args:
        strict: When ``False`` (path-less replace), unknown attributes are
            silently skipped. When ``True`` (explicit path), they raise.
    """
    if path in ignored_paths:
        return

    # Simple field writes handled by the dispatch table
    entry = _USER_REPLACE_PATHS.get(path)
    if entry:
        key, target = entry
        target_dict = ctx.data if target == "data" else ctx.name_data
        target_dict[key] = value
        return

    # displayName sets both the top-level field and the name.formatted sub-field
    if path == "displayname":
        ctx.data["displayName"] = value
        ctx.name_data["formatted"] = value
    elif path == "name":
        if isinstance(value, dict):
            for k, v in value.items():
                ctx.name_data[k] = v
    elif path == "emails":
        if isinstance(value, list):
            ctx.data["emails"] = value
    elif _EMAIL_FILTER_RE.match(path):
        _update_primary_email(ctx.data, value)
    elif path.startswith(_ENTERPRISE_URN_LOWER):
        _set_enterprise_field(path, value, ctx.ent_data)
    elif not strict:
        return
    else:
        raise ScimPatchError(f"Unsupported path '{path}' for User PATCH")


def _update_primary_email(data: dict[str, Any], value: ScimPatchValue) -> None:
    """Update the primary email entry via an email filter path."""
    emails: list[dict] = data.get("emails") or []
    for email_entry in emails:
        if email_entry.get("primary"):
            email_entry["value"] = value
            break
    else:
        emails.append({"value": value, "type": "work", "primary": True})
    data["emails"] = emails


def _to_dict(value: ScimPatchValue) -> dict | None:
    """Coerce a SCIM patch value to a plain dict if possible.

    Pydantic may parse raw dicts as ``ScimPatchResourceValue`` (which uses
    ``extra="allow"``), so we also dump those back to a dict.
    """
    if isinstance(value, dict):
        return value
    if isinstance(value, ScimPatchResourceValue):
        return value.model_dump(exclude_unset=True)
    return None


def _set_enterprise_field(
    path: str,
    value: ScimPatchValue,
    ent_data: dict[str, str | None],
) -> None:
    """Handle enterprise extension URN paths or value dicts."""
    # Full URN as key with dict value (path-less PATCH)
    # e.g. key="urn:...:user", value={"department": "Eng", "manager": {...}}
    if path == _ENTERPRISE_URN_LOWER:
        d = _to_dict(value)
        if d is not None:
            if "department" in d:
                ent_data["department"] = d["department"]
            if "manager" in d:
                mgr = d["manager"]
                if isinstance(mgr, dict):
                    ent_data["manager"] = mgr.get("value")
        return

    # Dotted URN path, e.g. "urn:...:user:department"
    suffix = path[len(_ENTERPRISE_URN_LOWER) :].lstrip(":").lower()
    if suffix == "department":
        ent_data["department"] = str(value) if value is not None else None
    elif suffix == "manager":
        d = _to_dict(value)
        if d is not None:
            ent_data["manager"] = d.get("value")
        elif isinstance(value, str):
            ent_data["manager"] = value
    else:
        # Unknown enterprise attributes are silently ignored rather than
        # rejected — IdPs may send attributes we don't model yet.
        logger.warning("Ignoring unknown enterprise extension attribute '%s'", suffix)


# ---------------------------------------------------------------------------
# Group PATCH
# ---------------------------------------------------------------------------


def apply_group_patch(
    operations: list[ScimPatchOperation],
    current: ScimGroupResource,
    ignored_paths: frozenset[str] = frozenset(),
) -> tuple[ScimGroupResource, list[str], list[str]]:
    """Apply SCIM PATCH operations to a group resource.

    Args:
        operations: The PATCH operations to apply.
        current: The current group resource state.
        ignored_paths: SCIM attribute paths to silently skip (from provider).

    Returns:
        A tuple of (modified group, added member IDs, removed member IDs).
        The caller uses the member ID lists to update the database.
@@ -133,7 +319,9 @@ def apply_group_patch(

    for op in operations:
        if op.op == ScimPatchOperationType.REPLACE:
            _apply_group_replace(op, data, current_members, added_ids, removed_ids)
            _apply_group_replace(
                op, data, current_members, added_ids, removed_ids, ignored_paths
            )
        elif op.op == ScimPatchOperationType.ADD:
            _apply_group_add(op, current_members, added_ids)
        elif op.op == ScimPatchOperationType.REMOVE:
@@ -154,38 +342,48 @@ def _apply_group_replace(
    current_members: list[dict],
    added_ids: list[str],
    removed_ids: list[str],
    ignored_paths: frozenset[str],
) -> None:
    """Apply a replace operation to group data."""
    path = (op.path or "").lower()

    if not path:
        if isinstance(op.value, dict):
            for key, val in op.value.items():
        if isinstance(op.value, ScimPatchResourceValue):
            dumped = op.value.model_dump(exclude_unset=True)
            for key, val in dumped.items():
                if key.lower() == "members":
                    _replace_members(val, current_members, added_ids, removed_ids)
                else:
                    _set_group_field(key.lower(), val, data)
                    _set_group_field(key.lower(), val, data, ignored_paths)
        else:
            raise ScimPatchError("Replace without path requires a dict value")
        return

    if path == "members":
        _replace_members(op.value, current_members, added_ids, removed_ids)
        _replace_members(
            _members_to_dicts(op.value), current_members, added_ids, removed_ids
        )
        return

    _set_group_field(path, op.value, data)
    _set_group_field(path, op.value, data, ignored_paths)


def _members_to_dicts(
    value: str | bool | list[ScimGroupMember] | ScimPatchResourceValue | None,
) -> list[dict]:
    """Convert a member list value to a list of dicts for internal processing."""
    if not isinstance(value, list):
        raise ScimPatchError("Replace members requires a list value")
    return [m.model_dump(exclude_none=True) for m in value]


def _replace_members(
    value: str | list | dict | bool | None,
    value: list[dict],
    current_members: list[dict],
    added_ids: list[str],
    removed_ids: list[str],
) -> None:
    """Replace the entire group member list."""
    if not isinstance(value, list):
        raise ScimPatchError("Replace members requires a list value")

    old_ids = {m["value"] for m in current_members}
    new_ids = {m.get("value", "") for m in value}

@@ -197,16 +395,21 @@ def _replace_members(

def _set_group_field(
    path: str,
    value: str | bool | dict | list | None,
    value: ScimPatchValue,
    data: dict,
    ignored_paths: frozenset[str],
) -> None:
    """Set a single field on group data by SCIM path."""
    if path == "displayname":
        data["displayName"] = value
    elif path == "externalid":
        data["externalId"] = value
    else:
        raise ScimPatchError(f"Unsupported path '{path}' for Group PATCH")
    if path in ignored_paths:
        return

    entry = _GROUP_REPLACE_PATHS.get(path)
    if entry:
        key, _ = entry
        data[key] = value
        return

    raise ScimPatchError(f"Unsupported path '{path}' for Group PATCH")


def _apply_group_add(
@@ -223,8 +426,10 @@ def _apply_group_add(
    if not isinstance(op.value, list):
        raise ScimPatchError("Add members requires a list value")

    member_dicts = [m.model_dump(exclude_none=True) for m in op.value]

    existing_ids = {m["value"] for m in members}
    for member_data in op.value:
    for member_data in member_dicts:
        member_id = member_data.get("value", "")
        if member_id and member_id not in existing_ids:
            members.append(member_data)
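
A sketch of the add flow, using fabricated member ids; only the behavior visible in this hunk is asserted, since the rest of the function is truncated here:

members = [{"value": "user-1", "display": "a@example.com"}]
op = ScimPatchOperation(
    op=ScimPatchOperationType.ADD,
    path="members",
    value=[ScimGroupMember(value="user-2", display="b@example.com")],
)
added_ids: list[str] = []
_apply_group_add(op, members, added_ids)
# Only members not already present are appended.
assert [m["value"] for m in members] == ["user-1", "user-2"]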

0  backend/ee/onyx/server/scim/providers/__init__.py  Normal file
215  backend/ee/onyx/server/scim/providers/base.py  Normal file
@@ -0,0 +1,215 @@
"""Base SCIM provider abstraction."""

from __future__ import annotations

import json
import logging
from abc import ABC
from abc import abstractmethod
from uuid import UUID

from pydantic import ValidationError

from ee.onyx.server.scim.models import SCIM_ENTERPRISE_USER_SCHEMA
from ee.onyx.server.scim.models import SCIM_USER_SCHEMA
from ee.onyx.server.scim.models import ScimEmail
from ee.onyx.server.scim.models import ScimEnterpriseExtension
from ee.onyx.server.scim.models import ScimGroupMember
from ee.onyx.server.scim.models import ScimGroupResource
from ee.onyx.server.scim.models import ScimManagerRef
from ee.onyx.server.scim.models import ScimMappingFields
from ee.onyx.server.scim.models import ScimMeta
from ee.onyx.server.scim.models import ScimName
from ee.onyx.server.scim.models import ScimUserGroupRef
from ee.onyx.server.scim.models import ScimUserResource
from onyx.db.models import User
from onyx.db.models import UserGroup


logger = logging.getLogger(__name__)

COMMON_IGNORED_PATCH_PATHS: frozenset[str] = frozenset(
    {
        "id",
        "schemas",
        "meta",
    }
)


class ScimProvider(ABC):
    """Base class for provider-specific SCIM behavior.

    Subclass this to handle IdP-specific quirks. The base class provides
    RFC 7643-compliant response builders that populate all standard fields.
    """

    @property
    @abstractmethod
    def name(self) -> str:
        """Short identifier for this provider (e.g. ``"okta"``)."""
        ...

    @property
    @abstractmethod
    def ignored_patch_paths(self) -> frozenset[str]:
        """SCIM attribute paths to silently skip in PATCH value-object dicts.

        IdPs may include read-only or meta fields alongside actual changes
        (e.g. Okta sends ``{"id": "...", "active": false}``). Paths listed
        here are silently dropped instead of raising an error.
        """
        ...

    @property
    def user_schemas(self) -> list[str]:
        """Schema URIs to include in User resource responses.

        Override in subclasses to advertise additional schemas (e.g. the
        enterprise extension for Entra ID).
        """
        return [SCIM_USER_SCHEMA]

    def build_user_resource(
        self,
        user: User,
        external_id: str | None = None,
        groups: list[tuple[int, str]] | None = None,
        scim_username: str | None = None,
        fields: ScimMappingFields | None = None,
    ) -> ScimUserResource:
        """Build a SCIM User response from an Onyx User.

        Args:
            user: The Onyx user model.
            external_id: The IdP's external identifier for this user.
            groups: List of ``(group_id, group_name)`` tuples for the
                ``groups`` read-only attribute. Pass ``None`` or ``[]``
                for newly-created users.
            scim_username: The original-case userName from the IdP. Falls
                back to ``user.email`` (lowercase) when not available.
            fields: Stored mapping fields that the IdP expects round-tripped.
        """
        f = fields or ScimMappingFields()
        group_refs = [
            ScimUserGroupRef(value=str(gid), display=gname)
            for gid, gname in (groups or [])
        ]

        username = scim_username or user.email

        # Build enterprise extension when at least one value is present.
        # Dynamically add the enterprise URN to schemas per RFC 7643 §3.0.
        enterprise_ext: ScimEnterpriseExtension | None = None
        schemas = list(self.user_schemas)
        if f.department is not None or f.manager is not None:
            manager_ref = (
                ScimManagerRef(value=f.manager) if f.manager is not None else None
            )
            enterprise_ext = ScimEnterpriseExtension(
                department=f.department,
                manager=manager_ref,
            )
            if SCIM_ENTERPRISE_USER_SCHEMA not in schemas:
                schemas.append(SCIM_ENTERPRISE_USER_SCHEMA)

        name = self.build_scim_name(user, f)
        emails = _deserialize_emails(f.scim_emails_json, username)

        resource = ScimUserResource(
            schemas=schemas,
            id=str(user.id),
            externalId=external_id,
            userName=username,
            name=name,
            displayName=user.personal_name,
            emails=emails,
            active=user.is_active,
            groups=group_refs,
            meta=ScimMeta(resourceType="User"),
        )
        resource.enterprise_extension = enterprise_ext
        return resource
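
A sketch of calling the builder directly; the user object is a stand-in exposing just the attributes the method reads (id, email, is_active, personal_name), not the real ORM model, and all values are invented:

from types import SimpleNamespace

from ee.onyx.server.scim.providers.okta import OktaProvider

stub_user = SimpleNamespace(
    id="5f2d0c9e", email="ada@example.com", is_active=True, personal_name="Ada Lovelace"
)
resource = OktaProvider().build_user_resource(
    stub_user,
    external_id="idp-123",
    fields=ScimMappingFields(department="Engineering"),
)
# A present department pulls the enterprise URN into resource.schemas.
assert SCIM_ENTERPRISE_USER_SCHEMA in resource.schemas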

    def build_group_resource(
        self,
        group: UserGroup,
        members: list[tuple[UUID, str | None]],
        external_id: str | None = None,
    ) -> ScimGroupResource:
        """Build a SCIM Group response from an Onyx UserGroup."""
        scim_members = [
            ScimGroupMember(value=str(uid), display=email) for uid, email in members
        ]
        return ScimGroupResource(
            id=str(group.id),
            externalId=external_id,
            displayName=group.name,
            members=scim_members,
            meta=ScimMeta(resourceType="Group"),
        )

    def build_scim_name(
        self,
        user: User,
        fields: ScimMappingFields,
    ) -> ScimName:
        """Build SCIM name components for the response.

        Round-trips stored ``given_name``/``family_name`` when available (so
        the IdP gets back what it sent). Falls back to splitting
        ``personal_name`` for users provisioned before we stored components.
        Always returns a ScimName — Okta's spec tests expect ``name``
        (with ``givenName``/``familyName``) on every user resource.
        Providers may override for custom behavior.
        """
        if fields.given_name is not None or fields.family_name is not None:
            return ScimName(
                givenName=fields.given_name or "",
                familyName=fields.family_name or "",
                formatted=user.personal_name or "",
            )
        if not user.personal_name:
            # Derive a reasonable name from the email so that SCIM spec tests
            # see non-empty givenName / familyName for every user resource.
            local = user.email.split("@")[0] if user.email else ""
            return ScimName(givenName=local, familyName="", formatted=local)
        parts = user.personal_name.split(" ", 1)
        return ScimName(
            givenName=parts[0],
            familyName=parts[1] if len(parts) > 1 else "",
            formatted=user.personal_name,
        )


def _deserialize_emails(stored_json: str | None, username: str) -> list[ScimEmail]:
    """Deserialize stored email entries or build a default work email."""
    if stored_json:
        try:
            entries = json.loads(stored_json)
            if isinstance(entries, list) and entries:
                return [ScimEmail(**e) for e in entries]
        except (json.JSONDecodeError, TypeError, ValidationError):
            logger.warning(
                "Corrupt scim_emails_json, falling back to default: %s", stored_json
            )
    return [ScimEmail(value=username, type="work", primary=True)]


def serialize_emails(emails: list[ScimEmail]) -> str | None:
    """Serialize SCIM email entries to JSON for storage."""
    if not emails:
        return None
    return json.dumps([e.model_dump(exclude_none=True) for e in emails])
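
The two helpers are designed to round-trip; a minimal check (addresses invented):

emails = [ScimEmail(value="ada@example.com", type="work", primary=True)]
stored = serialize_emails(emails)
assert _deserialize_emails(stored, "ada@example.com") == emails
# Corrupt JSON falls back to a default work email built from the username.
assert _deserialize_emails("not-json", "x@example.com")[0].value == "x@example.com"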


def get_default_provider() -> ScimProvider:
    """Return the default SCIM provider.

    Currently returns ``OktaProvider`` since Okta is the primary supported
    IdP. When provider detection is added (via token metadata or tenant
    config), this can be replaced with dynamic resolution.
    """
    from ee.onyx.server.scim.providers.okta import OktaProvider

    return OktaProvider()
36  backend/ee/onyx/server/scim/providers/entra.py  Normal file
@@ -0,0 +1,36 @@
"""Entra ID (Azure AD) SCIM provider."""

from __future__ import annotations

from ee.onyx.server.scim.models import SCIM_ENTERPRISE_USER_SCHEMA
from ee.onyx.server.scim.models import SCIM_USER_SCHEMA
from ee.onyx.server.scim.providers.base import COMMON_IGNORED_PATCH_PATHS
from ee.onyx.server.scim.providers.base import ScimProvider

_ENTRA_IGNORED_PATCH_PATHS = COMMON_IGNORED_PATCH_PATHS


class EntraProvider(ScimProvider):
    """Entra ID (Azure AD) SCIM provider.

    Entra behavioral notes:
    - Sends capitalized PATCH ops (``"Add"``, ``"Replace"``, ``"Remove"``)
      — handled by the ``ScimPatchOperation.normalize_operation`` validator.
    - Sends the enterprise extension URN as a key in path-less PATCH value
      dicts — handled by ``_set_enterprise_field`` in ``patch.py`` to
      store department/manager values.
    - Expects the enterprise extension schema in ``schemas`` arrays and
      ``/Schemas`` + ``/ResourceTypes`` discovery endpoints.
    """

    @property
    def name(self) -> str:
        return "entra"

    @property
    def ignored_patch_paths(self) -> frozenset[str]:
        return _ENTRA_IGNORED_PATCH_PATHS

    @property
    def user_schemas(self) -> list[str]:
        return [SCIM_USER_SCHEMA, SCIM_ENTERPRISE_USER_SCHEMA]
26  backend/ee/onyx/server/scim/providers/okta.py  Normal file
@@ -0,0 +1,26 @@
"""Okta SCIM provider."""

from __future__ import annotations

from ee.onyx.server.scim.providers.base import COMMON_IGNORED_PATCH_PATHS
from ee.onyx.server.scim.providers.base import ScimProvider


class OktaProvider(ScimProvider):
    """Okta SCIM provider.

    Okta behavioral notes:
    - Uses ``PATCH {"active": false}`` for deprovisioning (not DELETE)
    - Sends path-less PATCH with value dicts containing extra fields
      (``id``, ``schemas``)
    - Expects ``displayName`` and ``groups`` in user responses
    - Only uses ``eq`` operator for ``userName`` filter
    """

    @property
    def name(self) -> str:
        return "okta"

    @property
    def ignored_patch_paths(self) -> frozenset[str]:
        return COMMON_IGNORED_PATCH_PATHS
@@ -4,6 +4,7 @@ Pre-built at import time — these never change at runtime. Separated from
api.py to keep the endpoint module focused on request handling.
"""

from ee.onyx.server.scim.models import SCIM_ENTERPRISE_USER_SCHEMA
from ee.onyx.server.scim.models import SCIM_GROUP_SCHEMA
from ee.onyx.server.scim.models import SCIM_USER_SCHEMA
from ee.onyx.server.scim.models import ScimResourceType
@@ -20,6 +21,9 @@ USER_RESOURCE_TYPE = ScimResourceType.model_validate(
        "endpoint": "/scim/v2/Users",
        "description": "SCIM User resource",
        "schema": SCIM_USER_SCHEMA,
        "schemaExtensions": [
            {"schema": SCIM_ENTERPRISE_USER_SCHEMA, "required": False}
        ],
    }
)

@@ -104,6 +108,31 @@ USER_SCHEMA_DEF = ScimSchemaDefinition(
    ],
)

ENTERPRISE_USER_SCHEMA_DEF = ScimSchemaDefinition(
    id=SCIM_ENTERPRISE_USER_SCHEMA,
    name="EnterpriseUser",
    description="Enterprise User extension (RFC 7643 §4.3)",
    attributes=[
        ScimSchemaAttribute(
            name="department",
            type="string",
            description="Department.",
        ),
        ScimSchemaAttribute(
            name="manager",
            type="complex",
            description="The user's manager.",
            subAttributes=[
                ScimSchemaAttribute(
                    name="value",
                    type="string",
                    description="Manager user ID.",
                ),
            ],
        ),
    ],
)

GROUP_SCHEMA_DEF = ScimSchemaDefinition(
    id=SCIM_GROUP_SCHEMA,
    name="Group",

@@ -18,14 +18,15 @@ from ee.onyx.server.enterprise_settings.store import (
    store_settings as store_ee_settings,
)
from ee.onyx.server.enterprise_settings.store import upload_logo
from onyx.context.search.enums import RecencyBiasSetting
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.llm import fetch_existing_llm_provider
from onyx.db.llm import update_default_provider
from onyx.db.llm import upsert_llm_provider
from onyx.db.models import Tool
from onyx.db.persona import upsert_persona
from onyx.server.features.persona.models import PersonaUpsertRequest
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import LLMProviderView
from onyx.server.settings.models import Settings
from onyx.server.settings.store import store_settings as store_base_settings
from onyx.utils.logger import setup_logger
@@ -117,15 +118,44 @@ def _seed_custom_tools(db_session: Session, tools: List[CustomToolSeed]) -> None
def _seed_llms(
    db_session: Session, llm_upsert_requests: list[LLMProviderUpsertRequest]
) -> None:
    if llm_upsert_requests:
        logger.notice("Seeding LLMs")
        seeded_providers = [
            upsert_llm_provider(llm_upsert_request, db_session)
            for llm_upsert_request in llm_upsert_requests
        ]
        update_default_provider(
            provider_id=seeded_providers[0].id, db_session=db_session
        )
    if not llm_upsert_requests:
        return

    logger.notice("Seeding LLMs")
    for request in llm_upsert_requests:
        existing = fetch_existing_llm_provider(name=request.name, db_session=db_session)
        if existing:
            request.id = existing.id
    seeded_providers: list[LLMProviderView] = []
    for llm_upsert_request in llm_upsert_requests:
        try:
            seeded_providers.append(upsert_llm_provider(llm_upsert_request, db_session))
        except ValueError as e:
            logger.warning(
                "Failed to upsert LLM provider '%s' during seeding: %s",
                llm_upsert_request.name,
                e,
            )

    default_provider = next(
        (p for p in seeded_providers if p.model_configurations), None
    )
    if not default_provider:
        return

    visible_configs = [
        mc for mc in default_provider.model_configurations if mc.is_visible
    ]
    default_config = (
        visible_configs[0]
        if visible_configs
        else default_provider.model_configurations[0]
    )
    update_default_provider(
        provider_id=default_provider.id,
        model_name=default_config.name,
        db_session=db_session,
    )

def _seed_personas(db_session: Session, personas: list[PersonaUpsertRequest]) -> None:
@@ -137,12 +167,6 @@ def _seed_personas(db_session: Session, personas: list[PersonaUpsertRequest]) ->
            user=None,  # Seeding is done as admin
            name=persona.name,
            description=persona.description,
            num_chunks=(
                persona.num_chunks if persona.num_chunks is not None else 0.0
            ),
            llm_relevance_filter=persona.llm_relevance_filter,
            llm_filter_extraction=persona.llm_filter_extraction,
            recency_bias=RecencyBiasSetting.AUTO,
            document_set_ids=persona.document_set_ids,
            llm_model_provider_override=persona.llm_model_provider_override,
            llm_model_version_override=persona.llm_model_version_override,
@@ -154,6 +178,7 @@ def _seed_personas(db_session: Session, personas: list[PersonaUpsertRequest]) ->
            system_prompt=persona.system_prompt,
            task_prompt=persona.task_prompt,
            datetime_aware=persona.datetime_aware,
            featured=persona.featured,
            commit=False,
        )
        db_session.commit()

@@ -6,6 +6,7 @@ from sqlalchemy.exc import SQLAlchemyError
from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED
from ee.onyx.db.license import get_cached_license_metadata
from ee.onyx.db.license import refresh_license_cache
from onyx.cache.interface import CACHE_TRANSIENT_ERRORS
from onyx.configs.app_configs import ENTERPRISE_EDITION_ENABLED
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.server.settings.models import ApplicationStatus
@@ -109,6 +110,12 @@ def apply_license_status_to_settings(settings: Settings) -> Settings:
        if metadata.status == _BLOCKING_STATUS:
            settings.application_status = metadata.status
            settings.ee_features_enabled = False
        elif metadata.used_seats > metadata.seats:
            # License is valid but seat limit exceeded
            settings.application_status = ApplicationStatus.SEAT_LIMIT_EXCEEDED
            settings.seat_count = metadata.seats
            settings.used_seats = metadata.used_seats
            settings.ee_features_enabled = True
        else:
            # Has a valid license (GRACE_PERIOD/PAYMENT_REMINDER still allow EE features)
            settings.ee_features_enabled = True
@@ -119,7 +126,7 @@ def apply_license_status_to_settings(settings: Settings) -> Settings:
            # syncing) means indexed data may need protection.
            settings.application_status = _BLOCKING_STATUS
            settings.ee_features_enabled = False
    except RedisError as e:
    except CACHE_TRANSIENT_ERRORS as e:
        logger.warning(f"Failed to check license metadata for settings: {e}")
        # Fail closed - disable EE features if we can't verify license
        settings.ee_features_enabled = False

@@ -21,7 +21,6 @@ import asyncio
|
||||
import httpx
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
|
||||
from ee.onyx.auth.users import current_admin_user
|
||||
from ee.onyx.server.tenants.access import control_plane_dep
|
||||
@@ -43,6 +42,8 @@ from onyx.auth.users import User
|
||||
from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_OVERRIDE
|
||||
from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_URL
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
@@ -116,9 +117,14 @@ async def create_customer_portal_session(
|
||||
try:
|
||||
portal_url = fetch_customer_portal_session(tenant_id, return_url)
|
||||
return {"stripe_customer_portal_url": portal_url}
|
||||
except Exception as e:
|
||||
except OnyxError:
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("Failed to create customer portal session")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INTERNAL_ERROR,
|
||||
"Failed to create customer portal session",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/create-checkout-session")
|
||||
@@ -134,9 +140,14 @@ async def create_checkout_session(
|
||||
try:
|
||||
checkout_url = fetch_stripe_checkout_session(tenant_id, billing_period, seats)
|
||||
return {"stripe_checkout_url": checkout_url}
|
||||
except Exception as e:
|
||||
except OnyxError:
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("Failed to create checkout session")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INTERNAL_ERROR,
|
||||
"Failed to create checkout session",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/create-subscription-session")
|
||||
@@ -147,15 +158,20 @@ async def create_subscription_session(
|
||||
try:
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
if not tenant_id:
|
||||
raise HTTPException(status_code=400, detail="Tenant ID not found")
|
||||
raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, "Tenant ID not found")
|
||||
|
||||
billing_period = request.billing_period if request else "monthly"
|
||||
session_id = fetch_stripe_checkout_session(tenant_id, billing_period)
|
||||
return SubscriptionSessionResponse(sessionId=session_id)
|
||||
|
||||
except Exception as e:
|
||||
except OnyxError:
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("Failed to create subscription session")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INTERNAL_ERROR,
|
||||
"Failed to create subscription session",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/stripe-publishable-key")
|
||||
@@ -186,18 +202,18 @@ async def get_stripe_publishable_key() -> StripePublishableKeyResponse:
|
||||
if STRIPE_PUBLISHABLE_KEY_OVERRIDE:
|
||||
key = STRIPE_PUBLISHABLE_KEY_OVERRIDE.strip()
|
||||
if not key.startswith("pk_"):
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Invalid Stripe publishable key format",
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INTERNAL_ERROR,
|
||||
"Invalid Stripe publishable key format",
|
||||
)
|
||||
_stripe_publishable_key_cache = key
|
||||
return StripePublishableKeyResponse(publishable_key=key)
|
||||
|
||||
# Fall back to S3 bucket
|
||||
if not STRIPE_PUBLISHABLE_KEY_URL:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Stripe publishable key is not configured",
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INTERNAL_ERROR,
|
||||
"Stripe publishable key is not configured",
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -208,15 +224,15 @@ async def get_stripe_publishable_key() -> StripePublishableKeyResponse:
|
||||
|
||||
# Validate key format
|
||||
if not key.startswith("pk_"):
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Invalid Stripe publishable key format",
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INTERNAL_ERROR,
|
||||
"Invalid Stripe publishable key format",
|
||||
)
|
||||
|
||||
_stripe_publishable_key_cache = key
|
||||
return StripePublishableKeyResponse(publishable_key=key)
|
||||
except httpx.HTTPError:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to fetch Stripe publishable key",
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INTERNAL_ERROR,
|
||||
"Failed to fetch Stripe publishable key",
|
||||
)
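
The handler change above follows one pattern throughout: re-raise typed domain errors untouched, log and wrap anything unexpected. A generic sketch of that shape, where DomainError and "INTERNAL_ERROR" are hypothetical stand-ins for OnyxError/OnyxErrorCode:

import logging

logger = logging.getLogger(__name__)

class DomainError(Exception):
    def __init__(self, code: str, message: str) -> None:
        super().__init__(message)
        self.code = code

def create_session(fetch_url) -> dict:
    try:
        return {"url": fetch_url()}
    except DomainError:
        # Already carries a user-facing code/message; let middleware map it.
        raise
    except Exception:
        # Log the stack trace, but never surface str(e) to the client.
        logger.exception("Failed to create session")
        raise DomainError("INTERNAL_ERROR", "Failed to create session")

Note the old code returned str(e) in the 500 detail, which can leak internals; the new shape logs the exception server-side and returns only a fixed message.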

@@ -33,6 +33,7 @@ from onyx.configs.constants import MilestoneRecordType
 from onyx.db.engine.sql_engine import get_session_with_shared_schema
 from onyx.db.engine.sql_engine import get_session_with_tenant
+from onyx.db.image_generation import create_default_image_gen_config_from_api_key
 from onyx.db.llm import fetch_existing_llm_provider
 from onyx.db.llm import update_default_provider
 from onyx.db.llm import upsert_cloud_embedding_provider
 from onyx.db.llm import upsert_llm_provider
@@ -302,12 +303,17 @@ def configure_default_api_keys(db_session: Session) -> None:
 
     has_set_default_provider = False
 
-    def _upsert(request: LLMProviderUpsertRequest) -> None:
+    def _upsert(request: LLMProviderUpsertRequest, default_model: str) -> None:
         nonlocal has_set_default_provider
         try:
+            existing = fetch_existing_llm_provider(
+                name=request.name, db_session=db_session
+            )
+            if existing:
+                request.id = existing.id
             provider = upsert_llm_provider(request, db_session)
             if not has_set_default_provider:
-                update_default_provider(provider.id, db_session)
+                update_default_provider(provider.id, default_model, db_session)
                 has_set_default_provider = True
         except Exception as e:
             logger.error(f"Failed to configure {request.provider} provider: {e}")
@@ -325,14 +331,13 @@ def configure_default_api_keys(db_session: Session) -> None:
             name="OpenAI",
             provider=OPENAI_PROVIDER_NAME,
             api_key=OPENAI_DEFAULT_API_KEY,
-            default_model_name=default_model_name,
             model_configurations=_build_model_configuration_upsert_requests(
                 OPENAI_PROVIDER_NAME, recommendations
             ),
             api_key_changed=True,
             is_auto_mode=True,
         )
-        _upsert(openai_provider)
+        _upsert(openai_provider, default_model_name)
 
         # Create default image generation config using the OpenAI API key
         try:
@@ -361,14 +366,13 @@ def configure_default_api_keys(db_session: Session) -> None:
             name="Anthropic",
             provider=ANTHROPIC_PROVIDER_NAME,
             api_key=ANTHROPIC_DEFAULT_API_KEY,
-            default_model_name=default_model_name,
             model_configurations=_build_model_configuration_upsert_requests(
                 ANTHROPIC_PROVIDER_NAME, recommendations
            ),
             api_key_changed=True,
             is_auto_mode=True,
         )
-        _upsert(anthropic_provider)
+        _upsert(anthropic_provider, default_model_name)
     else:
         logger.info(
             "ANTHROPIC_DEFAULT_API_KEY not set, skipping Anthropic provider configuration"
@@ -393,14 +397,13 @@ def configure_default_api_keys(db_session: Session) -> None:
             name="Google Vertex AI",
             provider=VERTEXAI_PROVIDER_NAME,
             custom_config=custom_config,
-            default_model_name=default_model_name,
             model_configurations=_build_model_configuration_upsert_requests(
                 VERTEXAI_PROVIDER_NAME, recommendations
            ),
             api_key_changed=True,
             is_auto_mode=True,
         )
-        _upsert(vertexai_provider)
+        _upsert(vertexai_provider, default_model_name)
     else:
         logger.info(
             "VERTEXAI_DEFAULT_CREDENTIALS not set, skipping Vertex AI provider configuration"
@@ -432,12 +435,11 @@ def configure_default_api_keys(db_session: Session) -> None:
             name="OpenRouter",
             provider=OPENROUTER_PROVIDER_NAME,
             api_key=OPENROUTER_DEFAULT_API_KEY,
-            default_model_name=default_model_name,
             model_configurations=model_configurations,
             api_key_changed=True,
             is_auto_mode=True,
         )
-        _upsert(openrouter_provider)
+        _upsert(openrouter_provider, default_model_name)
     else:
         logger.info(
            "OPENROUTER_DEFAULT_API_KEY not set, skipping OpenRouter provider configuration"

@@ -5,6 +5,8 @@ from sqlalchemy.exc import IntegrityError
 from sqlalchemy.orm import Session
 
 from ee.onyx.db.user_group import add_users_to_user_group
+from ee.onyx.db.user_group import delete_user_group as db_delete_user_group
+from ee.onyx.db.user_group import fetch_user_group
 from ee.onyx.db.user_group import fetch_user_groups
 from ee.onyx.db.user_group import fetch_user_groups_for_user
 from ee.onyx.db.user_group import insert_user_group
@@ -20,6 +22,7 @@ from ee.onyx.server.user_group.models import UserGroupUpdate
 from onyx.auth.users import current_admin_user
 from onyx.auth.users import current_curator_or_admin_user
 from onyx.auth.users import current_user
+from onyx.configs.app_configs import DISABLE_VECTOR_DB
 from onyx.configs.constants import PUBLIC_API_TAGS
 from onyx.db.engine.sql_engine import get_session
 from onyx.db.models import User
@@ -37,12 +40,15 @@ def list_user_groups(
     db_session: Session = Depends(get_session),
 ) -> list[UserGroup]:
     if user.role == UserRole.ADMIN:
-        user_groups = fetch_user_groups(db_session, only_up_to_date=False)
+        user_groups = fetch_user_groups(
+            db_session, only_up_to_date=False, eager_load_for_snapshot=True
+        )
     else:
         user_groups = fetch_user_groups_for_user(
             db_session=db_session,
             user_id=user.id,
             only_curator_groups=user.role == UserRole.CURATOR,
+            eager_load_for_snapshot=True,
         )
 
     return [UserGroup.from_model(user_group) for user_group in user_groups]
@@ -150,3 +156,8 @@ def delete_user_group(
         prepare_user_group_for_deletion(db_session, user_group_id)
     except ValueError as e:
         raise HTTPException(status_code=404, detail=str(e))
+
+    if DISABLE_VECTOR_DB:
+        user_group = fetch_user_group(db_session, user_group_id)
+        if user_group:
+            db_delete_user_group(db_session, user_group)
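
The eager_load_for_snapshot flag added above exists to avoid N+1 queries when the endpoint serializes each group into a snapshot. A generic, self-contained SQLAlchemy 2.0 sketch of the idea; Group/Member are illustrative models, not the Onyx schema:

from __future__ import annotations

from sqlalchemy import ForeignKey, select
from sqlalchemy.orm import (
    DeclarativeBase,
    Mapped,
    Session,
    mapped_column,
    relationship,
    selectinload,
)

class Base(DeclarativeBase):
    pass

class Group(Base):
    __tablename__ = "group"
    id: Mapped[int] = mapped_column(primary_key=True)
    members: Mapped[list[Member]] = relationship(back_populates="group")

class Member(Base):
    __tablename__ = "member"
    id: Mapped[int] = mapped_column(primary_key=True)
    group_id: Mapped[int] = mapped_column(ForeignKey("group.id"))
    group: Mapped[Group] = relationship(back_populates="members")

def fetch_groups_for_snapshot(db_session: Session) -> list[Group]:
    # Load members in one extra query up front instead of one lazy query
    # per group while serializing snapshots (the N+1 problem).
    stmt = select(Group).options(selectinload(Group.members))
    return list(db_session.scalars(stmt).all())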

@@ -53,7 +53,8 @@ class UserGroup(BaseModel):
                 id=cc_pair_relationship.cc_pair.id,
                 name=cc_pair_relationship.cc_pair.name,
                 connector=ConnectorSnapshot.from_connector_db_model(
-                    cc_pair_relationship.cc_pair.connector
+                    cc_pair_relationship.cc_pair.connector,
+                    credential_ids=[cc_pair_relationship.cc_pair.credential_id],
                 ),
                 credential=CredentialSnapshot.from_credential_db_model(
                     cc_pair_relationship.cc_pair.credential
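
Snapshot models like ConnectorSnapshot are built from ORM rows via classmethod constructors. The real signature may differ; this sketch only illustrates why passing credential_ids explicitly (as the hunk above starts doing) can help, using hypothetical field names:

from dataclasses import dataclass, field

from pydantic import BaseModel

@dataclass
class ConnectorRow:  # stand-in for the ORM model
    id: int
    name: str
    credentials: list = field(default_factory=list)

class ConnectorSnapshot(BaseModel):
    id: int
    name: str
    credential_ids: list[int]

    @classmethod
    def from_connector_db_model(
        cls, connector: ConnectorRow, credential_ids: list[int] | None = None
    ) -> "ConnectorSnapshot":
        # An explicit credential_ids argument lets callers that already hold
        # the linked credential skip a lazy load of connector.credentials.
        return cls(
            id=connector.id,
            name=connector.name,
            credential_ids=credential_ids
            if credential_ids is not None
            else [c.id for c in connector.credentials],
        )

snap = ConnectorSnapshot.from_connector_db_model(
    ConnectorRow(id=1, name="drive"), credential_ids=[42]
)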

@@ -14,67 +14,91 @@ from onyx.utils.variable_functionality import fetch_versioned_implementation
 logger = setup_logger()
 
 
-@lru_cache(maxsize=1)
+@lru_cache(maxsize=2)
 def _get_trimmed_key(key: str) -> bytes:
     encoded_key = key.encode()
     key_length = len(encoded_key)
     if key_length < 16:
         raise RuntimeError("Invalid ENCRYPTION_KEY_SECRET - too short")
-    elif key_length > 32:
-        key = key[:32]
-    elif key_length not in (16, 24, 32):
-        valid_lengths = [16, 24, 32]
-        key = key[: min(valid_lengths, key=lambda x: abs(x - key_length))]
 
-    return encoded_key
+    # Trim to the largest valid AES key size that fits
+    valid_lengths = [32, 24, 16]
+    for size in valid_lengths:
+        if key_length >= size:
+            return encoded_key[:size]
+
+    raise AssertionError("unreachable")
 
 
-def _encrypt_string(input_str: str) -> bytes:
-    if not ENCRYPTION_KEY_SECRET:
+def _encrypt_string(input_str: str, key: str | None = None) -> bytes:
+    effective_key = key if key is not None else ENCRYPTION_KEY_SECRET
+    if not effective_key:
         return input_str.encode()
 
-    key = _get_trimmed_key(ENCRYPTION_KEY_SECRET)
+    trimmed = _get_trimmed_key(effective_key)
     iv = urandom(16)
     padder = padding.PKCS7(algorithms.AES.block_size).padder()
     padded_data = padder.update(input_str.encode()) + padder.finalize()
 
-    cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend())
+    cipher = Cipher(algorithms.AES(trimmed), modes.CBC(iv), backend=default_backend())
     encryptor = cipher.encryptor()
     encrypted_data = encryptor.update(padded_data) + encryptor.finalize()
 
     return iv + encrypted_data
 
 
-def _decrypt_bytes(input_bytes: bytes) -> str:
-    if not ENCRYPTION_KEY_SECRET:
+def _decrypt_bytes(input_bytes: bytes, key: str | None = None) -> str:
+    effective_key = key if key is not None else ENCRYPTION_KEY_SECRET
+    if not effective_key:
         return input_bytes.decode()
 
-    key = _get_trimmed_key(ENCRYPTION_KEY_SECRET)
-    iv = input_bytes[:16]
-    encrypted_data = input_bytes[16:]
+    trimmed = _get_trimmed_key(effective_key)
+    try:
+        iv = input_bytes[:16]
+        encrypted_data = input_bytes[16:]
 
-    cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend())
-    decryptor = cipher.decryptor()
-    decrypted_padded_data = decryptor.update(encrypted_data) + decryptor.finalize()
+        cipher = Cipher(
+            algorithms.AES(trimmed), modes.CBC(iv), backend=default_backend()
+        )
+        decryptor = cipher.decryptor()
+        decrypted_padded_data = decryptor.update(encrypted_data) + decryptor.finalize()
 
-    unpadder = padding.PKCS7(algorithms.AES.block_size).unpadder()
-    decrypted_data = unpadder.update(decrypted_padded_data) + unpadder.finalize()
+        unpadder = padding.PKCS7(algorithms.AES.block_size).unpadder()
+        decrypted_data = unpadder.update(decrypted_padded_data) + unpadder.finalize()
 
-    return decrypted_data.decode()
+        return decrypted_data.decode()
+    except (ValueError, UnicodeDecodeError):
+        if key is not None:
+            # Explicit key was provided — don't fall back silently
+            raise
+        # Read path: attempt raw UTF-8 decode as a fallback for legacy data.
+        # Does NOT handle data encrypted with a different key — that
+        # ciphertext is not valid UTF-8 and will raise below.
+        logger.warning(
+            "AES decryption failed — falling back to raw decode. "
+            "Run the re-encrypt secrets script to rotate to the current key."
+        )
+        try:
+            return input_bytes.decode()
+        except UnicodeDecodeError:
+            raise ValueError(
+                "Data is not valid UTF-8 — likely encrypted with a different key. "
+                "Run the re-encrypt secrets script to rotate to the current key."
+            ) from None
 
 
-def encrypt_string_to_bytes(input_str: str) -> bytes:
+def encrypt_string_to_bytes(input_str: str, key: str | None = None) -> bytes:
     versioned_encryption_fn = fetch_versioned_implementation(
         "onyx.utils.encryption", "_encrypt_string"
     )
-    return versioned_encryption_fn(input_str)
+    return versioned_encryption_fn(input_str, key=key)
 
 
-def decrypt_bytes_to_string(input_bytes: bytes) -> str:
+def decrypt_bytes_to_string(input_bytes: bytes, key: str | None = None) -> str:
     versioned_decryption_fn = fetch_versioned_implementation(
         "onyx.utils.encryption", "_decrypt_bytes"
     )
-    return versioned_decryption_fn(input_bytes)
+    return versioned_decryption_fn(input_bytes, key=key)
 
 
 def test_encryption() -> None:
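
To make the key-trimming and IV handling concrete, here is a self-contained round-trip using the same cryptography-library primitives as the module above (the secret value is illustrative). The bump from maxsize=1 to maxsize=2 on the lru_cache presumably lines up with the new optional key parameter: during rotation the cache can hold both the current secret and an explicitly supplied one.

from os import urandom

from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import padding
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes

def trim_key(secret: str) -> bytes:
    encoded = secret.encode()
    if len(encoded) < 16:
        raise RuntimeError("secret too short")
    # Largest valid AES key size (32/24/16 bytes) that fits.
    for size in (32, 24, 16):
        if len(encoded) >= size:
            return encoded[:size]
    raise AssertionError("unreachable")

def encrypt(plaintext: str, secret: str) -> bytes:
    key = trim_key(secret)
    iv = urandom(16)  # fresh IV per message, prepended to the ciphertext
    padder = padding.PKCS7(algorithms.AES.block_size).padder()
    padded = padder.update(plaintext.encode()) + padder.finalize()
    cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend())
    encryptor = cipher.encryptor()
    return iv + encryptor.update(padded) + encryptor.finalize()

def decrypt(blob: bytes, secret: str) -> str:
    key = trim_key(secret)
    iv, ciphertext = blob[:16], blob[16:]
    cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend())
    decryptor = cipher.decryptor()
    padded = decryptor.update(ciphertext) + decryptor.finalize()
    unpadder = padding.PKCS7(algorithms.AES.block_size).unpadder()
    return (unpadder.update(padded) + unpadder.finalize()).decode()

secret = "0123456789abcdef0123"  # 20 bytes: trims down to a 16-byte AES key
assert decrypt(encrypt("hello", secret), secret) == "hello"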

@@ -1,7 +1,6 @@
 from collections.abc import Callable
 from typing import cast
 
-from sqlalchemy.orm import joinedload
 from sqlalchemy.orm import Session
 
 from onyx.access.models import DocumentAccess
@@ -12,6 +11,7 @@ from onyx.db.document import get_access_info_for_document
 from onyx.db.document import get_access_info_for_documents
 from onyx.db.models import User
 from onyx.db.models import UserFile
+from onyx.db.user_file import fetch_user_files_with_access_relationships
 from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
 from onyx.utils.variable_functionality import fetch_versioned_implementation
 
@@ -132,19 +132,61 @@ def get_access_for_user_files(
     user_file_ids: list[str],
     db_session: Session,
 ) -> dict[str, DocumentAccess]:
-    user_files = (
-        db_session.query(UserFile)
-        .options(joinedload(UserFile.user))  # Eager load the user relationship
-        .filter(UserFile.id.in_(user_file_ids))
-        .all()
+    versioned_fn = fetch_versioned_implementation(
+        "onyx.access.access", "get_access_for_user_files_impl"
     )
-    return {
-        str(user_file.id): DocumentAccess.build(
-            user_emails=[user_file.user.email] if user_file.user else [],
+    return versioned_fn(user_file_ids, db_session)
+
+
+def get_access_for_user_files_impl(
+    user_file_ids: list[str],
+    db_session: Session,
+) -> dict[str, DocumentAccess]:
+    user_files = fetch_user_files_with_access_relationships(user_file_ids, db_session)
+    return build_access_for_user_files_impl(user_files)
+
+
+def build_access_for_user_files(
+    user_files: list[UserFile],
+) -> dict[str, DocumentAccess]:
+    """Compute access from pre-loaded UserFile objects (with relationships).
+    Callers must ensure UserFile.user, Persona.users, and Persona.user are
+    eagerly loaded (and Persona.groups for the EE path)."""
+    versioned_fn = fetch_versioned_implementation(
+        "onyx.access.access", "build_access_for_user_files_impl"
+    )
+    return versioned_fn(user_files)
+
+
+def build_access_for_user_files_impl(
+    user_files: list[UserFile],
+) -> dict[str, DocumentAccess]:
+    result: dict[str, DocumentAccess] = {}
+    for user_file in user_files:
+        emails, is_public = collect_user_file_access(user_file)
+        result[str(user_file.id)] = DocumentAccess.build(
+            user_emails=list(emails),
             user_groups=[],
-            is_public=True if user_file.user is None else False,
+            is_public=is_public,
             external_user_emails=[],
             external_user_group_ids=[],
         )
-        for user_file in user_files
-    }
+    return result
+
+
+def collect_user_file_access(user_file: UserFile) -> tuple[set[str], bool]:
+    """Collect all user emails that should have access to this user file.
+    Includes the owner plus any users who have access via shared personas.
+    Returns (emails, is_public)."""
+    emails: set[str] = {user_file.user.email}
+    is_public = False
+    for persona in user_file.assistants:
+        if persona.deleted:
+            continue
+        if persona.is_public:
+            is_public = True
+        if persona.user_id is not None and persona.user:
+            emails.add(persona.user.email)
+        for shared_user in persona.users:
+            emails.add(shared_user.email)
+    return emails, is_public
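
fetch_versioned_implementation resolves a function by module path and attribute name so an enterprise build can override the MIT implementation while the *_impl function stays directly importable and testable. A rough sketch of that mechanism under that assumption (the real helper's lookup and fallback logic may differ):

import importlib

def fetch_impl(module: str, attribute: str):
    # Try the enterprise module first, then fall back to the open one.
    for candidate in (f"ee.{module}", module):
        try:
            return getattr(importlib.import_module(candidate), attribute)
        except (ImportError, AttributeError):
            continue
    raise RuntimeError(f"No implementation found for {module}.{attribute}")

# Usage mirroring the diff: callers go through the dispatcher,
# tests can call the _impl function directly.
get_access = fetch_impl("onyx.access.access", "get_access_for_user_files_impl")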

@@ -58,16 +58,27 @@ class OAuthTokenManager:
         if not user_token.token_data:
             raise ValueError("No token data available for refresh")
 
+        if (
+            self.oauth_config.client_id is None
+            or self.oauth_config.client_secret is None
+        ):
+            raise ValueError(
+                "OAuth client_id and client_secret are required for token refresh"
+            )
+
         token_data = self._unwrap_token_data(user_token.token_data)
 
+        data: dict[str, str] = {
+            "grant_type": "refresh_token",
+            "refresh_token": token_data["refresh_token"],
+            "client_id": self._unwrap_sensitive_str(self.oauth_config.client_id),
+            "client_secret": self._unwrap_sensitive_str(
+                self.oauth_config.client_secret
+            ),
+        }
         response = requests.post(
             self.oauth_config.token_url,
-            data={
-                "grant_type": "refresh_token",
-                "refresh_token": token_data["refresh_token"],
-                "client_id": self.oauth_config.client_id,
-                "client_secret": self.oauth_config.client_secret,
-            },
+            data=data,
             headers={"Accept": "application/json"},
         )
         response.raise_for_status()
@@ -115,15 +126,26 @@ class OAuthTokenManager:
 
     def exchange_code_for_token(self, code: str, redirect_uri: str) -> dict[str, Any]:
         """Exchange authorization code for access token"""
+        if (
+            self.oauth_config.client_id is None
+            or self.oauth_config.client_secret is None
+        ):
+            raise ValueError(
+                "OAuth client_id and client_secret are required for code exchange"
+            )
+
+        data: dict[str, str] = {
+            "grant_type": "authorization_code",
+            "code": code,
+            "client_id": self._unwrap_sensitive_str(self.oauth_config.client_id),
+            "client_secret": self._unwrap_sensitive_str(
+                self.oauth_config.client_secret
+            ),
+            "redirect_uri": redirect_uri,
+        }
         response = requests.post(
             self.oauth_config.token_url,
-            data={
-                "grant_type": "authorization_code",
-                "code": code,
-                "client_id": self.oauth_config.client_id,
-                "client_secret": self.oauth_config.client_secret,
-                "redirect_uri": redirect_uri,
-            },
+            data=data,
             headers={"Accept": "application/json"},
         )
         response.raise_for_status()
@@ -141,8 +163,13 @@ class OAuthTokenManager:
         oauth_config: OAuthConfig, redirect_uri: str, state: str
     ) -> str:
         """Build OAuth authorization URL"""
+        if oauth_config.client_id is None:
+            raise ValueError("OAuth client_id is required to build authorization URL")
+
         params: dict[str, Any] = {
-            "client_id": oauth_config.client_id,
+            "client_id": OAuthTokenManager._unwrap_sensitive_str(
+                oauth_config.client_id
+            ),
             "redirect_uri": redirect_uri,
             "response_type": "code",
             "state": state,
@@ -161,6 +188,12 @@ class OAuthTokenManager:
 
         return f"{oauth_config.authorization_url}{separator}{urlencode(params)}"
 
+    @staticmethod
+    def _unwrap_sensitive_str(value: SensitiveValue[str] | str) -> str:
+        if isinstance(value, SensitiveValue):
+            return value.get_value(apply_mask=False)
+        return value
+
     @staticmethod
     def _unwrap_token_data(
         token_data: SensitiveValue[dict[str, Any]] | dict[str, Any],
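
A standalone sketch of the refresh-token grant performed above, using plain requests; the endpoint and credential values are placeholders. The diff's point is that client credentials may now arrive wrapped in a SensitiveValue, so they are unwrapped just before the POST rather than stored in plain form:

import requests

def refresh_access_token(
    token_url: str, refresh_token: str, client_id: str, client_secret: str
) -> dict:
    response = requests.post(
        token_url,
        data={
            "grant_type": "refresh_token",
            "refresh_token": refresh_token,
            "client_id": client_id,
            "client_secret": client_secret,
        },
        # Some providers default to form-encoded responses; ask for JSON.
        headers={"Accept": "application/json"},
        timeout=30,
    )
    response.raise_for_status()
    return response.json()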

@@ -1,4 +1,5 @@
 import json
+import os
 import random
 import secrets
 import string
@@ -120,7 +121,7 @@ from onyx.db.models import User
 from onyx.db.pat import fetch_user_for_pat
 from onyx.db.users import get_user_by_email
 from onyx.redis.redis_pool import get_async_redis_connection
-from onyx.redis.redis_pool import get_redis_client
+from onyx.server.settings.store import load_settings
 from onyx.server.utils import BasicAuthenticationError
 from onyx.utils.logger import setup_logger
 from onyx.utils.telemetry import mt_cloud_telemetry
@@ -137,16 +138,30 @@ from shared_configs.contextvars import get_current_tenant_id
 
 logger = setup_logger()
 
+REGISTER_INVITE_ONLY_CODE = "REGISTER_INVITE_ONLY"
+
 
 def is_user_admin(user: User) -> bool:
     return user.role == UserRole.ADMIN
 
 
 def verify_auth_setting() -> None:
-    if AUTH_TYPE == AuthType.CLOUD:
+    """Log warnings for AUTH_TYPE issues.
+
+    This only runs on app startup, not during migrations/scripts.
+    """
+    raw_auth_type = (os.environ.get("AUTH_TYPE") or "").lower()
+
+    if raw_auth_type == "cloud":
         raise ValueError(
-            f"{AUTH_TYPE.value} is not a valid auth type for self-hosted deployments."
+            "'cloud' is not a valid auth type for self-hosted deployments."
         )
+    if raw_auth_type == "disabled":
+        logger.warning(
+            "AUTH_TYPE='disabled' is no longer supported. "
+            "Using 'basic' instead. Please update your configuration."
+        )
 
     logger.notice(f"Using Auth Type: {AUTH_TYPE.value}")
@@ -198,32 +213,45 @@ def user_needs_to_be_verified() -> bool:
 
 
 def anonymous_user_enabled(*, tenant_id: str | None = None) -> bool:
-    redis_client = get_redis_client(tenant_id=tenant_id)
-    value = redis_client.get(OnyxRedisLocks.ANONYMOUS_USER_ENABLED)
+    from onyx.cache.factory import get_cache_backend
+
+    cache = get_cache_backend(tenant_id=tenant_id)
+    value = cache.get(OnyxRedisLocks.ANONYMOUS_USER_ENABLED)
 
     if value is None:
         return False
 
     assert isinstance(value, bytes)
     return int(value.decode("utf-8")) == 1
 
 
+def workspace_invite_only_enabled() -> bool:
+    settings = load_settings()
+    return settings.invite_only_enabled
+
+
 def verify_email_is_invited(email: str) -> None:
+    if AUTH_TYPE in {AuthType.SAML, AuthType.OIDC}:
+        # SSO providers manage membership; allow JIT provisioning regardless of invites
+        return
+
-    whitelist = get_invited_users()
-    if not whitelist:
+    if not workspace_invite_only_enabled():
         return
 
+    whitelist = get_invited_users()
+
     if not email:
-        raise PermissionError("Email must be specified")
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail={"reason": "Email must be specified"},
+        )
 
     try:
         email_info = validate_email(email, check_deliverability=False)
     except EmailUndeliverableError:
-        raise PermissionError("Email is not valid")
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail={"reason": "Email is not valid"},
+        )
 
     for email_whitelist in whitelist:
         try:
@@ -240,7 +268,13 @@ def verify_email_is_invited(email: str) -> None:
         if email_info.normalized.lower() == email_info_whitelist.normalized.lower():
             return
 
-    raise PermissionError("User not on allowed user whitelist")
+    raise HTTPException(
+        status_code=status.HTTP_403_FORBIDDEN,
+        detail={
+            "code": REGISTER_INVITE_ONLY_CODE,
+            "reason": "This workspace is invite-only. Please ask your admin to invite you.",
+        },
+    )
 
 
 def verify_email_in_whitelist(email: str, tenant_id: str) -> None:
@@ -256,13 +290,32 @@ def verify_email_domain(email: str) -> None:
             detail="Email is not valid",
         )
 
-    domain = email.split("@")[-1].lower()
+    local_part, domain = email.split("@")
+    domain = domain.lower()
 
     if AUTH_TYPE == AuthType.CLOUD:
+        # googlemail.com delivers to the same inbox as gmail.com;
+        # reject it so each inbox maps to a single account
+        if domain == "googlemail.com":
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail={"reason": "Please use @gmail.com instead of @googlemail.com."},
+            )
+
+        if "+" in local_part and domain != "onyx.app":
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail={
+                    "reason": "Email addresses with '+' are not allowed. Please use your base email address."
+                },
+            )
+
         # Check if email uses a disposable/temporary domain
         if is_disposable_email(email):
             raise HTTPException(
                 status_code=status.HTTP_400_BAD_REQUEST,
-                detail="Disposable email addresses are not allowed. Please use a permanent email address.",
+                detail={
+                    "reason": "Disposable email addresses are not allowed. Please use a permanent email address."
+                },
             )
 
     # Check domain whitelist if configured
@@ -503,7 +556,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
             result = await db_session.execute(
                 select(Persona.id)
                 .where(
-                    Persona.is_default_persona.is_(True),
+                    Persona.featured.is_(True),
                     Persona.is_public.is_(True),
                     Persona.is_visible.is_(True),
                     Persona.deleted.is_(False),
@@ -685,11 +738,19 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
             if user_by_session:
                 user = user_by_session
 
+            # If the user is inactive, check seat availability before
+            # upgrading role — otherwise they'd become an inactive BASIC
+            # user who still can't log in.
+            if not user.is_active:
+                with get_session_with_current_tenant() as sync_db:
+                    enforce_seat_limit(sync_db)
+
             await self.user_db.update(
                 user,
                 {
                     "is_verified": is_verified_by_default,
                     "role": UserRole.BASIC,
+                    **({"is_active": True} if not user.is_active else {}),
                 },
             )
@@ -1650,7 +1711,10 @@ def get_oauth_router(
     if redirect_url is not None:
         authorize_redirect_url = redirect_url
     else:
-        authorize_redirect_url = str(request.url_for(callback_route_name))
+        # Use WEB_DOMAIN instead of request.url_for() to prevent host
+        # header poisoning — request.url_for() trusts the Host header.
+        callback_path = request.app.url_path_for(callback_route_name)
+        authorize_redirect_url = f"{WEB_DOMAIN}{callback_path}"
 
     next_url = request.query_params.get("next", "/")
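
A focused sketch of the host-header fix above: never derive the OAuth callback host from the incoming request, because request.url_for() echoes whatever Host header the client sent. Here WEB_DOMAIN stands for the trusted, configured origin and the route name is illustrative:

from fastapi import FastAPI, Request

WEB_DOMAIN = "https://app.example.com"  # from configuration, not the request

app = FastAPI()

@app.get("/auth/callback", name="oauth_callback")
async def oauth_callback() -> dict:
    return {}

@app.get("/auth/authorize")
async def authorize(request: Request) -> dict:
    # url_path_for() yields only the path, which we pin to WEB_DOMAIN,
    # so a forged Host header cannot redirect the OAuth flow elsewhere.
    callback_path = request.app.url_path_for("oauth_callback")
    return {"redirect_uri": f"{WEB_DOMAIN}{callback_path}"}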

@@ -1,142 +0,0 @@
-from typing import Any
-from typing import cast
-
-from celery import Celery
-from celery import signals
-from celery import Task
-from celery.apps.worker import Worker
-from celery.signals import celeryd_init
-from celery.signals import worker_init
-from celery.signals import worker_process_init
-from celery.signals import worker_ready
-from celery.signals import worker_shutdown
-
-import onyx.background.celery.apps.app_base as app_base
-from onyx.background.celery.celery_utils import httpx_init_vespa_pool
-from onyx.configs.app_configs import MANAGED_VESPA
-from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
-from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
-from onyx.configs.constants import POSTGRES_CELERY_WORKER_BACKGROUND_APP_NAME
-from onyx.db.engine.sql_engine import SqlEngine
-from onyx.utils.logger import setup_logger
-from shared_configs.configs import MULTI_TENANT
-
-
-logger = setup_logger()
-
-celery_app = Celery(__name__)
-celery_app.config_from_object("onyx.background.celery.configs.background")
-celery_app.Task = app_base.TenantAwareTask  # type: ignore [misc]
-
-
-@signals.task_prerun.connect
-def on_task_prerun(
-    sender: Any | None = None,
-    task_id: str | None = None,
-    task: Task | None = None,
-    args: tuple | None = None,
-    kwargs: dict | None = None,
-    **kwds: Any,
-) -> None:
-    app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
-
-
-@signals.task_postrun.connect
-def on_task_postrun(
-    sender: Any | None = None,
-    task_id: str | None = None,
-    task: Task | None = None,
-    args: tuple | None = None,
-    kwargs: dict | None = None,
-    retval: Any | None = None,
-    state: str | None = None,
-    **kwds: Any,
-) -> None:
-    app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)
-
-
-@celeryd_init.connect
-def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
-    app_base.on_celeryd_init(sender, conf, **kwargs)
-
-
-@worker_init.connect
-def on_worker_init(sender: Worker, **kwargs: Any) -> None:
-    EXTRA_CONCURRENCY = 8  # small extra fudge factor for connection limits
-
-    logger.info("worker_init signal received for consolidated background worker.")
-
-    SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_BACKGROUND_APP_NAME)
-    pool_size = cast(int, sender.concurrency)  # type: ignore
-    SqlEngine.init_engine(pool_size=pool_size, max_overflow=EXTRA_CONCURRENCY)
-
-    # Initialize Vespa httpx pool (needed for light worker tasks)
-    if MANAGED_VESPA:
-        httpx_init_vespa_pool(
-            sender.concurrency + EXTRA_CONCURRENCY,  # type: ignore
-            ssl_cert=VESPA_CLOUD_CERT_PATH,
-            ssl_key=VESPA_CLOUD_KEY_PATH,
-        )
-    else:
-        httpx_init_vespa_pool(sender.concurrency + EXTRA_CONCURRENCY)  # type: ignore
-
-    app_base.wait_for_redis(sender, **kwargs)
-    app_base.wait_for_db(sender, **kwargs)
-    app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
-
-    # Less startup checks in multi-tenant case
-    if MULTI_TENANT:
-        return
-
-    app_base.on_secondary_worker_init(sender, **kwargs)
-
-
-@worker_ready.connect
-def on_worker_ready(sender: Any, **kwargs: Any) -> None:
-    app_base.on_worker_ready(sender, **kwargs)
-
-
-@worker_shutdown.connect
-def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
-    app_base.on_worker_shutdown(sender, **kwargs)
-
-
-@worker_process_init.connect
-def init_worker(**kwargs: Any) -> None:  # noqa: ARG001
-    SqlEngine.reset_engine()
-
-
-@signals.setup_logging.connect
-def on_setup_logging(
-    loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any
-) -> None:
-    app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
-
-
-base_bootsteps = app_base.get_bootsteps()
-for bootstep in base_bootsteps:
-    celery_app.steps["worker"].add(bootstep)
-
-celery_app.autodiscover_tasks(
-    app_base.filter_task_modules(
-        [
-            # Original background worker tasks
-            "onyx.background.celery.tasks.pruning",
-            "onyx.background.celery.tasks.monitoring",
-            "onyx.background.celery.tasks.user_file_processing",
-            "onyx.background.celery.tasks.llm_model_update",
-            # Light worker tasks
-            "onyx.background.celery.tasks.shared",
-            "onyx.background.celery.tasks.vespa",
-            "onyx.background.celery.tasks.connector_deletion",
-            "onyx.background.celery.tasks.doc_permission_syncing",
-            "onyx.background.celery.tasks.opensearch_migration",
-            # Docprocessing worker tasks
-            "onyx.background.celery.tasks.docprocessing",
-            # Docfetching worker tasks
-            "onyx.background.celery.tasks.docfetching",
-            # Sandbox cleanup tasks (isolated in build feature)
-            "onyx.server.features.build.sandbox.tasks",
-        ]
-    )
-)

@@ -39,9 +39,13 @@ CT = TypeVar("CT", bound=ConnectorCheckpoint)
 
 
 class SlimConnectorExtractionResult(BaseModel):
-    """Result of extracting document IDs and hierarchy nodes from a connector."""
+    """Result of extracting document IDs and hierarchy nodes from a connector.
+
+    raw_id_to_parent maps document ID → parent_hierarchy_raw_node_id (or None).
+    Use raw_id_to_parent.keys() wherever the old set of IDs was needed.
+    """
 
-    doc_ids: set[str]
+    raw_id_to_parent: dict[str, str | None]
     hierarchy_nodes: list[HierarchyNode]
 
 
@@ -93,30 +97,34 @@ def _get_failure_id(failure: ConnectorFailure) -> str | None:
     return None
 
 
+class BatchResult(BaseModel):
+    raw_id_to_parent: dict[str, str | None]
+    hierarchy_nodes: list[HierarchyNode]
+
+
 def _extract_from_batch(
     doc_list: Sequence[Document | SlimDocument | HierarchyNode | ConnectorFailure],
-) -> tuple[set[str], list[HierarchyNode]]:
-    """Separate a batch into document IDs and hierarchy nodes.
+) -> BatchResult:
+    """Separate a batch into document IDs (with parent mapping) and hierarchy nodes.
 
     ConnectorFailure items have their failed document/entity IDs added to the
-    ID set so that failed-to-retrieve documents are not accidentally pruned.
+    ID dict so that failed-to-retrieve documents are not accidentally pruned.
     """
-    ids: set[str] = set()
+    ids: dict[str, str | None] = {}
     hierarchy_nodes: list[HierarchyNode] = []
     for item in doc_list:
         if isinstance(item, HierarchyNode):
             hierarchy_nodes.append(item)
-            ids.add(item.raw_node_id)
         elif isinstance(item, ConnectorFailure):
             failed_id = _get_failure_id(item)
             if failed_id:
-                ids.add(failed_id)
+                ids[failed_id] = None
                 logger.warning(
                     f"Failed to retrieve document {failed_id}: " f"{item.failure_message}"
                 )
         else:
-            ids.add(item.id)
-    return ids, hierarchy_nodes
+            ids[item.id] = item.parent_hierarchy_raw_node_id
+    return BatchResult(raw_id_to_parent=ids, hierarchy_nodes=hierarchy_nodes)
 
 
 def extract_ids_from_runnable_connector(
@@ -132,7 +140,7 @@ def extract_ids_from_runnable_connector(
 
     Optionally, a callback can be passed to handle the length of each document batch.
     """
-    all_connector_doc_ids: set[str] = set()
+    all_raw_id_to_parent: dict[str, str | None] = {}
     all_hierarchy_nodes: list[HierarchyNode] = []
 
     # Sequence (covariant) lets all the specific list[...] iterator types unify here
@@ -177,15 +185,18 @@ def extract_ids_from_runnable_connector(
                 "extract_ids_from_runnable_connector: Stop signal detected"
             )
 
-        batch_ids, batch_nodes = _extract_from_batch(doc_list)
-        all_connector_doc_ids.update(doc_batch_processing_func(batch_ids))
+        batch_result = _extract_from_batch(doc_list)
+        batch_ids = batch_result.raw_id_to_parent
+        batch_nodes = batch_result.hierarchy_nodes
+        doc_batch_processing_func(batch_ids)
+        all_raw_id_to_parent.update(batch_ids)
         all_hierarchy_nodes.extend(batch_nodes)
 
         if callback:
             callback.progress("extract_ids_from_runnable_connector", len(batch_ids))
 
     return SlimConnectorExtractionResult(
-        doc_ids=all_connector_doc_ids,
+        raw_id_to_parent=all_raw_id_to_parent,
         hierarchy_nodes=all_hierarchy_nodes,
    )
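
The dict-based accumulation above preserves parent linkage while staying a drop-in replacement for the old set, since keys() gives the old behavior. A toy version of the partitioning step, with stand-in item classes:

from dataclasses import dataclass

@dataclass
class Node:
    raw_node_id: str

@dataclass
class Doc:
    id: str
    parent_hierarchy_raw_node_id: str | None

@dataclass
class Failure:
    failed_id: str

def partition(batch: list) -> tuple[dict[str, str | None], list[Node]]:
    id_to_parent: dict[str, str | None] = {}
    nodes: list[Node] = []
    for item in batch:
        if isinstance(item, Node):
            nodes.append(item)
        elif isinstance(item, Failure):
            # Keep failed IDs so a later pruning pass can't drop them.
            id_to_parent[item.failed_id] = None
        else:
            id_to_parent[item.id] = item.parent_hierarchy_raw_node_id
    return id_to_parent, nodes

ids, nodes = partition([Doc("d1", "folder-1"), Failure("d2"), Node("folder-1")])
assert set(ids) == {"d1", "d2"}  # keys() behaves like the old set of IDs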

@@ -1,23 +0,0 @@
-import onyx.background.celery.configs.base as shared_config
-from onyx.configs.app_configs import CELERY_WORKER_BACKGROUND_CONCURRENCY
-
-broker_url = shared_config.broker_url
-broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
-broker_pool_limit = shared_config.broker_pool_limit
-broker_transport_options = shared_config.broker_transport_options
-
-redis_socket_keepalive = shared_config.redis_socket_keepalive
-redis_retry_on_timeout = shared_config.redis_retry_on_timeout
-redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval
-
-result_backend = shared_config.result_backend
-result_expires = shared_config.result_expires  # 86400 seconds is the default
-
-task_default_priority = shared_config.task_default_priority
-task_acks_late = shared_config.task_acks_late
-
-worker_concurrency = CELERY_WORKER_BACKGROUND_CONCURRENCY
-worker_pool = "threads"
-# Increased from 1 to 4 to handle fast light worker tasks more efficiently
-# This allows the worker to prefetch multiple tasks per thread
-worker_prefetch_multiplier = 4

@@ -241,8 +241,7 @@ _VECTOR_DB_BEAT_TASK_NAMES: set[str] = {
     "check-for-index-attempt-cleanup",
     "check-for-doc-permissions-sync",
     "check-for-external-group-sync",
-    "check-for-documents-for-opensearch-migration",
     "migrate-documents-from-vespa-to-opensearch",
     "migrate-chunks-from-vespa-to-opensearch",
 }
 
 if DISABLE_VECTOR_DB:

Some files were not shown because too many files have changed in this diff.