Compare commits

..

1 Commits
gating ... a

Author SHA1 Message Date
pablodanswer
f47d6798e1 temp 2024-10-22 09:33:41 -07:00
18 changed files with 69 additions and 84 deletions

View File

@@ -136,7 +136,6 @@ DocumentSourceRequiringTenantContext: list[DocumentSource] = [DocumentSource.FIL
class NotificationType(str, Enum):
REINDEX = "reindex"
PERSONA_SHARED = "persona_shared"
TRIAL_ENDS_TWO_DAYS = "two_day_trial_ending" # 2 days left in trial
class BlobType(str, Enum):

View File

@@ -268,27 +268,34 @@ async def get_async_session_with_tenant(
) -> AsyncGenerator[AsyncSession, None]:
if tenant_id is None:
tenant_id = current_tenant_id.get()
else:
current_tenant_id.set(tenant_id)
if not is_valid_schema_name(tenant_id):
logger.error(f"Invalid tenant ID: {tenant_id}")
raise Exception("Invalid tenant ID")
engine = get_sqlalchemy_async_engine()
async_session_factory = sessionmaker(
bind=engine, expire_on_commit=False, class_=AsyncSession
) # type: ignore
bind=engine,
class_=AsyncSession,
expire_on_commit=False,
autoflush=False,
)
async with async_session_factory() as session:
try:
# Set the search_path to the tenant's schema
# Start a SAVEPOINT to ensure the SET command is effective
async with session.begin():
# Set the search_path at the session level
await session.execute(text(f'SET search_path = "{tenant_id}"'))
except Exception as e:
logger.error(f"Error setting search_path: {str(e)}")
# You can choose to re-raise the exception or handle it
# Here, we'll re-raise to prevent proceeding with an incorrect session
raise
else:
try:
yield session
finally:
# Optionally reset the search_path after the session ends
if MULTI_TENANT:
async with session.begin():
await session.execute(text('SET search_path TO "$user", public'))
@contextmanager

View File

@@ -113,7 +113,7 @@ class OAuthAccount(SQLAlchemyBaseOAuthAccountTableUUID, Base):
class User(SQLAlchemyBaseUserTableUUID, Base):
oauth_accounts: Mapped[list[OAuthAccount]] = relationship(
"OAuthAccount", lazy="joined", cascade="all, delete-orphan"
"OAuthAccount", lazy="selectin", cascade="all, delete-orphan"
)
role: Mapped[UserRole] = mapped_column(
Enum(UserRole, native_enum=False, default=UserRole.BASIC)

View File

@@ -4,7 +4,6 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from sqlalchemy.sql import func
from danswer.auth.schemas import UserRole
from danswer.configs.constants import NotificationType
from danswer.db.models import Notification
from danswer.db.models import User
@@ -55,9 +54,7 @@ def get_notification_by_id(
notif = db_session.get(Notification, notification_id)
if not notif:
raise ValueError(f"No notification found with id {notification_id}")
if notif.user_id != user_id and not (
notif.user_id is None and user.role == UserRole.ADMIN
):
if notif.user_id != user_id:
raise PermissionError(
f"User {user_id} is not authorized to access notification {notification_id}"
)

View File

@@ -79,8 +79,6 @@ def _get_answer_stream_processor(
doc_id_to_rank_map: DocumentIdOrderMapping,
answer_style_configs: AnswerStyleConfig,
) -> StreamProcessor:
print("ANSWERR STYES")
print(answer_style_configs.__dict__)
if answer_style_configs.citation_config:
return build_citation_processor(
context_docs=context_docs, doc_id_to_rank_map=doc_id_to_rank_map

View File

@@ -226,7 +226,6 @@ def process_model_tokens(
hold_quote = ""
for token in tokens:
print(f"Token: {token}")
model_previous = model_output
model_output += token

View File

@@ -54,7 +54,6 @@ def fetch_settings(
Postgres calls"""
general_settings = load_settings()
user_notifications = get_reindex_notification(user, db_session)
product_gating_notification = get_product_gating_notification(db_session)
try:
kv_store = get_kv_store()
@@ -62,27 +61,11 @@ def fetch_settings(
except KvKeyNotFoundError:
needs_reindexing = False
print("product_gating_notification", product_gating_notification)
# TODO: Clean up
print("response is ", [product_gating_notification])
response = UserSettings(
return UserSettings(
**general_settings.model_dump(),
notifications=[product_gating_notification]
if product_gating_notification
else user_notifications,
notifications=user_notifications,
needs_reindexing=needs_reindexing,
)
print("act is ", response)
return response
def get_product_gating_notification(db_session: Session) -> Notification | None:
notification = get_notifications(
user=None,
notif_type=NotificationType.TRIAL_ENDS_TWO_DAYS,
db_session=db_session,
)
return Notification.from_model(notification[0]) if notification else None
def get_reindex_notification(

View File

@@ -8,7 +8,6 @@ from danswer.auth.users import User
from danswer.configs.app_configs import MULTI_TENANT
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.db.engine import get_session_with_tenant
from danswer.db.notification import create_notification
from danswer.server.settings.store import load_settings
from danswer.server.settings.store import store_settings
from danswer.setup import setup_danswer
@@ -88,17 +87,12 @@ def gate_product(
1) User has ended free trial without adding payment method
2) User's card has declined
"""
tenant_id = product_gating_request.tenant_id
token = current_tenant_id.set(tenant_id)
token = current_tenant_id.set(current_tenant_id.get())
settings = load_settings()
settings.product_gating = product_gating_request.product_gating
store_settings(settings)
if product_gating_request.notification:
with get_session_with_tenant(tenant_id) as db_session:
create_notification(None, product_gating_request.notification, db_session)
if token is not None:
current_tenant_id.reset(token)

View File

@@ -1,6 +1,5 @@
from pydantic import BaseModel
from danswer.configs.constants import NotificationType
from danswer.server.settings.models import GatingType
@@ -16,7 +15,6 @@ class CreateTenantRequest(BaseModel):
class ProductGatingRequest(BaseModel):
tenant_id: str
product_gating: GatingType
notification: NotificationType | None = None
class BillingInformation(BaseModel):

View File

@@ -313,7 +313,7 @@ services:
- POSTGRES_USER=${POSTGRES_USER:-postgres}
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
ports:
- "5433:5432"
- "5432:5432"
volumes:
- db_volume:/var/lib/postgresql/data

View File

@@ -313,7 +313,7 @@ services:
- POSTGRES_USER=${POSTGRES_USER:-postgres}
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
ports:
- "5433:5432"
- "5432:5432"
volumes:
- db_volume:/var/lib/postgresql/data

View File

@@ -157,7 +157,7 @@ services:
- POSTGRES_USER=${POSTGRES_USER:-postgres}
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
ports:
- "5433"
- "5432"
volumes:
- db_volume:/var/lib/postgresql/data

View File

@@ -18,7 +18,6 @@ export interface Settings {
export enum NotificationType {
PERSONA_SHARED = "persona_shared",
REINDEX_NEEDED = "reindex_needed",
TRIAL_ENDS_TWO_DAYS = "two_day_trial_ending",
}
export interface Notification {

View File

@@ -1,4 +1,4 @@
import { NEXT_PUBLIC_CLOUD_ENABLED } from "@/lib/constants";
import { CLOUD_ENABLED } from "@/lib/constants";
import { getAuthTypeMetadataSS, logoutSS } from "@/lib/userSS";
import { NextRequest } from "next/server";
@@ -13,7 +13,7 @@ export const POST = async (request: NextRequest) => {
}
// Delete cookies only if cloud is enabled (jwt auth)
if (NEXT_PUBLIC_CLOUD_ENABLED) {
if (CLOUD_ENABLED) {
const cookiesToDelete = ["fastapiusersauth", "tenant_details"];
const cookieOptions = {
path: "/",

View File

@@ -8,8 +8,10 @@ import {
} from "@/lib/userSS";
import { redirect } from "next/navigation";
import { EmailPasswordForm } from "../login/EmailPasswordForm";
import { Text } from "@tremor/react";
import { Card, Title, Text } from "@tremor/react";
import Link from "next/link";
import { Logo } from "@/components/Logo";
import { CLOUD_ENABLED } from "@/lib/constants";
import { SignInButton } from "../login/SignInButton";
import AuthFlowContainer from "@/components/auth/AuthFlowContainer";

View File

@@ -174,6 +174,20 @@ export default async function RootLayout({
process.env.THEME_IS_DARK?.toLowerCase() === "true" ? "dark" : ""
}`}
>
{productGating === GatingType.PARTIAL && (
<div className="fixed top-0 left-0 right-0 z-50 bg-warning-100 text-warning-900 p-2 text-center">
<p className="text-sm font-medium">
Your account is pending payment!{" "}
<a
href="/admin/cloud-settings"
className="font-bold underline hover:text-warning-700 transition-colors"
>
Update your billing information
</a>{" "}
or access will be suspended soon.
</p>
</div>
)}
<UserProvider>
<ProviderContextProvider>
<SettingsProvider settings={combinedSettings}>

View File

@@ -27,6 +27,7 @@ export async function Layout({ children }: { children: React.ReactNode }) {
const authTypeMetadata = results[0] as AuthTypeMetadata | null;
const user = results[1] as User | null;
console.log("authTypeMetadata", authTypeMetadata);
const authDisabled = authTypeMetadata?.authType === "disabled";
const requiresVerification = authTypeMetadata?.requiresVerification;

View File

@@ -15,7 +15,6 @@ export function AnnouncementBanner() {
settings?.settings.notifications || []
);
console.log("notifications", localNotifications);
useEffect(() => {
const filteredNotifications = (
settings?.settings.notifications || []
@@ -33,7 +32,7 @@ export function AnnouncementBanner() {
const handleDismiss = async (notificationId: number) => {
try {
const response = await fetch(
`/api/notifications/${notificationId}/dismiss`,
`/api/settings/notifications/${notificationId}/dismiss`,
{
method: "POST",
}
@@ -62,12 +61,12 @@ export function AnnouncementBanner() {
{localNotifications
.filter((notification) => !notification.dismissed)
.map((notification) => {
return (
<div
key={notification.id}
className="absolute top-0 left-1/2 transform -translate-x-1/2 bg-blue-600 rounded-sm text-white px-4 pr-8 py-3 mx-auto"
>
{notification.notif_type == "reindex" ? (
if (notification.notif_type == "reindex") {
return (
<div
key={notification.id}
className="absolute top-0 left-1/2 transform -translate-x-1/2 bg-blue-600 rounded-sm text-white px-4 pr-8 py-3 mx-auto"
>
<p className="text-center">
Your index is out of date - we strongly recommend updating
your search settings.{" "}
@@ -78,29 +77,24 @@ export function AnnouncementBanner() {
Update here
</Link>
</p>
) : notification.notif_type == "two_day_trial_ending" ? (
<p className="text-center">
Your trial is ending soon - submit your billing information to
continue using Danswer.{" "}
<Link
href="/admin/cloud-settings"
className="ml-2 underline cursor-pointer"
<button
onClick={() => handleDismiss(notification.id)}
className="absolute top-0 right-0 mt-2 mr-2"
aria-label="Dismiss"
>
<CustomTooltip
showTick
citation
delay={100}
content="Dismiss"
>
Update here
</Link>
</p>
) : null}
<button
onClick={() => handleDismiss(notification.id)}
className="absolute top-0 right-0 mt-2 mr-2"
aria-label="Dismiss"
>
<CustomTooltip showTick citation delay={100} content="Dismiss">
<XIcon className="h-5 w-5" />
</CustomTooltip>
</button>
</div>
);
<XIcon className="h-5 w-5" />
</CustomTooltip>
</button>
</div>
);
}
return null;
})}
</>
);