diff --git a/care/emr/api/viewsets/facility_flag.py b/care/emr/api/viewsets/facility_flag.py new file mode 100644 index 0000000000..57da9fac09 --- /dev/null +++ b/care/emr/api/viewsets/facility_flag.py @@ -0,0 +1,108 @@ +from django.db import transaction +from django_filters import rest_framework as filters +from rest_framework.decorators import action +from rest_framework.exceptions import PermissionDenied +from rest_framework.response import Response + +from care.emr.api.viewsets.base import ( + EMRBaseViewSet, + EMRCreateMixin, + EMRDestroyMixin, + EMRListMixin, + EMRRetrieveMixin, + EMRUpdateMixin, +) +from care.emr.resources.facility_flag.spec import ( + FacilityFlagCreateSpec, + FacilityFlagReadSpec, + FacilityFlagRetrieveSpec, + FacilityFlagUpdateSpec, +) +from care.facility.models import FacilityFlag +from care.security.authorization.base import AuthorizationController +from care.utils.registries.feature_flag import FlagNotFoundError, FlagRegistry, FlagType + + +class FacilityFlagFilters(filters.FilterSet): + facility = filters.UUIDFilter(field_name="facility__external_id") + flag = filters.CharFilter(field_name="flag", lookup_expr="iexact") + + +class FacilityFlagViewSet( + EMRCreateMixin, + EMRRetrieveMixin, + EMRUpdateMixin, + EMRListMixin, + EMRDestroyMixin, + EMRBaseViewSet, +): + database_model = FacilityFlag + pydantic_model = FacilityFlagCreateSpec + pydantic_update_model = FacilityFlagUpdateSpec + pydantic_read_model = FacilityFlagReadSpec + pydantic_retrieve_model = FacilityFlagRetrieveSpec + filterset_class = FacilityFlagFilters + filter_backends = [filters.DjangoFilterBackend] + + def authorize_create(self, instance): + if not AuthorizationController.call( + "can_write_facility_flag", self.request.user + ): + raise PermissionDenied( + "You do not have permission to create facility flags" + ) + + def authorize_update(self, request_obj, model_instance): + if not AuthorizationController.call( + "can_write_facility_flag", self.request.user + ): + raise PermissionDenied( + "You do not have permission to update facility flags" + ) + + def authorize_destroy(self, instance): + if not AuthorizationController.call( + "can_write_facility_flag", self.request.user + ): + raise PermissionDenied( + "You do not have permission to delete facility flags" + ) + + def get_queryset(self): + if not AuthorizationController.call( + "can_read_facility_flag", self.request.user + ): + raise PermissionDenied("You do not have permission to list facility flags") + return super().get_queryset() + + def perform_create(self, instance): + with transaction.atomic(): + super().perform_create(instance) + FlagRegistry.register(FlagType.FACILITY, instance.flag) + + def perform_destroy(self, instance): + with transaction.atomic(): + flag_name = instance.flag + super().perform_destroy(instance) + + still_used = FacilityFlag.objects.filter( + flag=flag_name, deleted=False + ).exists() + + if not still_used: + FlagRegistry.unregister(FlagType.FACILITY, flag_name) + + @action(detail=False, methods=["GET"], url_path="available-flags") + def available_flags(self, request): + if not AuthorizationController.call( + "can_read_facility_flag", self.request.user + ): + raise PermissionDenied("You do not have permission to view available flags") + + try: + flags = FlagRegistry.get_all_flags(FlagType.FACILITY) + return Response({"available_flags": list(flags)}) + except FlagNotFoundError: + return Response( + {"message": "No registered flag type 'facility' found."}, status=400 + ) diff --git a/care/emr/api/viewsets/user_flag.py b/care/emr/api/viewsets/user_flag.py new file mode 100644 index 0000000000..3e987daa6c --- /dev/null +++ b/care/emr/api/viewsets/user_flag.py @@ -0,0 +1,89 @@ +from django.db import transaction +from django_filters import rest_framework as filters +from rest_framework.decorators import action +from rest_framework.exceptions import PermissionDenied +from rest_framework.response import Response + +from care.emr.api.viewsets.base import ( + EMRBaseViewSet, + EMRCreateMixin, + EMRDestroyMixin, + EMRListMixin, + EMRRetrieveMixin, + EMRUpdateMixin, +) +from care.emr.resources.user_flag.spec import ( + UserFlagCreateSpec, + UserFlagReadSpec, + UserFlagRetrieveSpec, + UserFlagUpdateSpec, +) +from care.security.authorization.base import AuthorizationController +from care.users.models import UserFlag +from care.utils.registries.feature_flag import FlagNotFoundError, FlagRegistry, FlagType + + +class UserFlagFilters(filters.FilterSet): + user = filters.UUIDFilter(field_name="user__external_id") + flag = filters.CharFilter(field_name="flag", lookup_expr="iexact") + + +class UserFlagViewSet( + EMRCreateMixin, + EMRRetrieveMixin, + EMRUpdateMixin, + EMRListMixin, + EMRDestroyMixin, + EMRBaseViewSet, +): + database_model = UserFlag + pydantic_model = UserFlagCreateSpec + pydantic_update_model = UserFlagUpdateSpec + pydantic_read_model = UserFlagReadSpec + pydantic_retrieve_model = UserFlagRetrieveSpec + filterset_class = UserFlagFilters + filter_backends = [filters.DjangoFilterBackend] + + def authorize_create(self, instance): + if not AuthorizationController.call("can_write_user_flag", self.request.user): + raise PermissionDenied("You do not have permission to create user flags") + + def authorize_update(self, request_obj, model_instance): + if not AuthorizationController.call("can_write_user_flag", self.request.user): + raise PermissionDenied("You do not have permission to update user flags") + + def authorize_destroy(self, instance): + if not AuthorizationController.call("can_write_user_flag", self.request.user): + raise PermissionDenied("You do not have permission to delete user flags") + + def get_queryset(self): + if not AuthorizationController.call("can_read_user_flag", self.request.user): + raise PermissionDenied("You do not have permission to list user flags") + return super().get_queryset() + + def perform_create(self, instance): + with transaction.atomic(): + super().perform_create(instance) + FlagRegistry.register(FlagType.USER, instance.flag) + + def perform_destroy(self, instance): + with transaction.atomic(): + flag_name = instance.flag + super().perform_destroy(instance) + + still_used = UserFlag.objects.filter(flag=flag_name, deleted=False).exists() + + if not still_used: + FlagRegistry.unregister(FlagType.USER, flag_name) + + @action(detail=False, methods=["GET"], url_path="available-flags") + def available_flags(self, request): + if not AuthorizationController.call("can_read_user_flag", self.request.user): + raise PermissionDenied("You do not have permission to view available flags") + try: + flags = FlagRegistry.get_all_flags(FlagType.USER) + return Response({"available_flags": list(flags)}) + except FlagNotFoundError: + return Response( + {"message": "No registered flag type 'user' found."}, status=400 + ) diff --git a/care/emr/resources/facility_flag/__init__.py b/care/emr/resources/facility_flag/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/care/emr/resources/facility_flag/spec.py b/care/emr/resources/facility_flag/spec.py new file mode 100644 index 0000000000..7516f58ef6 --- /dev/null +++ b/care/emr/resources/facility_flag/spec.py @@ -0,0 +1,51 @@ +from pydantic import UUID4, field_validator + +from care.emr.resources.base import EMRResource +from care.emr.resources.facility.spec import FacilityBareMinimumSpec +from care.facility.models import FacilityFlag +from care.facility.models.facility import Facility +from care.utils.registries.feature_flag import FlagRegistry, FlagType +from care.utils.shortcuts import get_object_or_404 + + +class FacilityFlagBaseSpec(EMRResource): + __model__ = FacilityFlag + __exclude__ = ["facility"] + + id: UUID4 | None = None + flag: str + + +class FacilityFlagCreateSpec(FacilityFlagBaseSpec): + facility: UUID4 + + @field_validator("flag") + @classmethod + def validate_flag_name(cls, flag_name): + FlagRegistry.validate_flag_name(FlagType.FACILITY, flag_name) + return flag_name + + def perform_extra_deserialization(self, is_update, obj): + if not is_update: + obj.facility = get_object_or_404(Facility, external_id=self.facility) + + +class FacilityFlagUpdateSpec(FacilityFlagBaseSpec): + @field_validator("flag") + @classmethod + def validate_flag_name(cls, flag_name): + FlagRegistry.validate_flag_name(FlagType.FACILITY, flag_name) + return flag_name + + +class FacilityFlagReadSpec(FacilityFlagBaseSpec): + facility: dict + + @classmethod + def perform_extra_serialization(cls, mapping, obj): + mapping["id"] = obj.external_id + mapping["facility"] = FacilityBareMinimumSpec.serialize(obj.facility).to_json() + + +class FacilityFlagRetrieveSpec(FacilityFlagReadSpec): + pass diff --git a/care/emr/resources/user_flag/__init__.py b/care/emr/resources/user_flag/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/care/emr/resources/user_flag/spec.py b/care/emr/resources/user_flag/spec.py new file mode 100644 index 0000000000..c07af148df --- /dev/null +++ b/care/emr/resources/user_flag/spec.py @@ -0,0 +1,50 @@ +from pydantic import UUID4, field_validator + +from care.emr.resources.base import EMRResource +from care.emr.resources.user.spec import UserSpec +from care.users.models import User, UserFlag +from care.utils.registries.feature_flag import FlagRegistry, FlagType +from care.utils.shortcuts import get_object_or_404 + + +class UserFlagBaseSpec(EMRResource): + __model__ = UserFlag + __exclude__ = ["user"] + + id: UUID4 | None = None + flag: str + + +class UserFlagCreateSpec(UserFlagBaseSpec): + user: UUID4 + + @field_validator("flag") + @classmethod + def validate_flag_name(cls, flag_name): + FlagRegistry.validate_flag_name(FlagType.USER, flag_name) + return flag_name + + def perform_extra_deserialization(self, is_update, obj): + if not is_update: + obj.user = get_object_or_404(User, external_id=self.user) + + +class UserFlagUpdateSpec(UserFlagBaseSpec): + @field_validator("flag") + @classmethod + def validate_flag_name(cls, flag_name): + FlagRegistry.validate_flag_name(FlagType.USER, flag_name) + return flag_name + + +class UserFlagReadSpec(UserFlagBaseSpec): + user: dict + + @classmethod + def perform_extra_serialization(cls, mapping, obj): + mapping["id"] = obj.external_id + mapping["user"] = UserSpec.serialize(obj.user).to_json() + + +class UserFlagRetrieveSpec(UserFlagReadSpec): + pass diff --git a/care/emr/tests/test_facility_flag_api.py b/care/emr/tests/test_facility_flag_api.py new file mode 100644 index 0000000000..f761d2cc24 --- /dev/null +++ b/care/emr/tests/test_facility_flag_api.py @@ -0,0 +1,448 @@ +from django.urls import reverse + +from care.facility.models import FacilityFlag +from care.utils.registries.feature_flag import FlagRegistry, FlagType +from care.utils.tests.base import CareAPITestBase + + +class FacilityFlagAPITestCase(CareAPITestBase): + def setUp(self): + super().setUp() + # Register test flags + FlagRegistry.register(FlagType.FACILITY, "TEST_FACILITY_FLAG") + FlagRegistry.register(FlagType.FACILITY, "TEST_FACILITY_FLAG_2") + FlagRegistry.register(FlagType.FACILITY, "ENABLE_FEATURE_X") + + self.superuser = self.create_super_user(username="superuser") + self.normal_user = self.create_user(username="normaluser") + + self.facility = self.create_facility(name="Test Facility", user=self.superuser) + self.facility_2 = self.create_facility( + name="Test Facility 2", user=self.superuser + ) + + self.base_url = reverse("facility-flags-list") + + def get_detail_url(self, external_id): + return reverse("facility-flags-detail", kwargs={"external_id": external_id}) + + # ========== List Tests ========== + + def test_list_facility_flags_as_superuser(self): + """Test that superuser can list all facility flags""" + self.create_facility_flag(facility=self.facility, flag="TEST_FACILITY_FLAG") + self.create_facility_flag(facility=self.facility_2, flag="TEST_FACILITY_FLAG_2") + + self.client.force_authenticate(user=self.superuser) + response = self.client.get(self.base_url) + self.assertEqual(response.status_code, 200) + data = response.json() + self.assertEqual(data["count"], 2) + + def test_list_facility_flags_filtered_by_facility(self): + """Test that superuser can filter facility flags by facility""" + self.create_facility_flag(facility=self.facility, flag="TEST_FACILITY_FLAG") + self.create_facility_flag(facility=self.facility_2, flag="TEST_FACILITY_FLAG_2") + + self.client.force_authenticate(user=self.superuser) + response = self.client.get( + f"{self.base_url}?facility={self.facility.external_id}" + ) + self.assertEqual(response.status_code, 200) + data = response.json() + self.assertEqual(data["count"], 1) + self.assertEqual(data["results"][0]["flag"], "TEST_FACILITY_FLAG") + + def test_list_facility_flags_as_normal_user(self): + """Test that normal user cannot list facility flags""" + self.create_facility_flag(facility=self.facility, flag="TEST_FACILITY_FLAG") + + self.client.force_authenticate(user=self.normal_user) + response = self.client.get(self.base_url) + self.assertEqual(response.status_code, 403) + + def test_list_facility_flags_unauthenticated(self): + """Test that unauthenticated user cannot list facility flags""" + response = self.client.get(self.base_url) + # get_queryset authorization check happens before authentication check + self.assertEqual(response.status_code, 403) + + def test_filter_facility_flags_by_flag_name(self): + """Test filtering facility flags by flag name (case-insensitive)""" + self.create_facility_flag(facility=self.facility, flag="TEST_FACILITY_FLAG") + self.create_facility_flag(facility=self.facility_2, flag="TEST_FACILITY_FLAG_2") + + self.client.force_authenticate(user=self.superuser) + + # Test exact case + response = self.client.get(f"{self.base_url}?flag=TEST_FACILITY_FLAG") + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json()["count"], 1) + + # Test case-insensitive + response = self.client.get(f"{self.base_url}?flag=test_facility_flag") + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json()["count"], 1) + + def test_list_multiple_flags_same_facility(self): + """Test listing multiple flags for the same facility""" + self.create_facility_flag(facility=self.facility, flag="TEST_FACILITY_FLAG") + self.create_facility_flag(facility=self.facility, flag="TEST_FACILITY_FLAG_2") + + self.client.force_authenticate(user=self.superuser) + response = self.client.get( + f"{self.base_url}?facility={self.facility.external_id}" + ) + self.assertEqual(response.status_code, 200) + data = response.json() + self.assertEqual(data["count"], 2) + + # ========== Create Tests ========== + + def test_create_facility_flag_as_superuser(self): + """Test that superuser can create facility flag""" + self.client.force_authenticate(user=self.superuser) + response = self.client.post( + self.base_url, + { + "flag": "TEST_FACILITY_FLAG", + "facility": self.facility.external_id, + }, + format="json", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.data["flag"], "TEST_FACILITY_FLAG") + + # Verify flag was created + self.assertTrue( + FacilityFlag.objects.filter( + facility=self.facility, flag="TEST_FACILITY_FLAG", deleted=False + ).exists() + ) + + def test_create_facility_flag_as_normal_user(self): + """Test that normal user cannot create facility flag""" + self.client.force_authenticate(user=self.normal_user) + response = self.client.post( + self.base_url, + { + "flag": "TEST_FACILITY_FLAG", + "facility": self.facility.external_id, + }, + format="json", + ) + self.assertEqual(response.status_code, 403) + + def test_create_facility_flag_with_invalid_facility_uuid(self): + """Test that creating facility flag with invalid facility UUID fails""" + self.client.force_authenticate(user=self.superuser) + response = self.client.post( + self.base_url, + { + "flag": "TEST_FACILITY_FLAG", + "facility": "00000000-0000-0000-0000-000000000000", + }, + format="json", + ) + # Invalid UUID causes validation error (400), not 404 + self.assertEqual(response.status_code, 400) + + def test_create_facility_flag_missing_required_fields(self): + """Test that creating facility flag without required fields fails""" + self.client.force_authenticate(user=self.superuser) + + # Missing flag + response = self.client.post( + self.base_url, + {"facility": self.facility.external_id}, + format="json", + ) + self.assertEqual(response.status_code, 400) + + # Missing facility + response = self.client.post( + self.base_url, + {"flag": "TEST_FACILITY_FLAG"}, + format="json", + ) + self.assertEqual(response.status_code, 400) + + def test_create_same_flag_different_facilities(self): + """Test that same flag can be created for different facilities""" + self.client.force_authenticate(user=self.superuser) + + # Create flag for facility 1 + response1 = self.client.post( + self.base_url, + { + "flag": "ENABLE_FEATURE_X", + "facility": self.facility.external_id, + }, + format="json", + ) + self.assertEqual(response1.status_code, 200) + + # Create same flag for facility 2 - should succeed + response2 = self.client.post( + self.base_url, + { + "flag": "ENABLE_FEATURE_X", + "facility": self.facility_2.external_id, + }, + format="json", + ) + self.assertEqual(response2.status_code, 200) + + # ========== Retrieve Tests ========== + + def test_retrieve_facility_flag_as_superuser(self): + """Test that superuser can retrieve facility flag""" + flag = self.create_facility_flag( + facility=self.facility, flag="TEST_FACILITY_FLAG" + ) + + self.client.force_authenticate(user=self.superuser) + response = self.client.get(self.get_detail_url(flag.external_id)) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.data["flag"], "TEST_FACILITY_FLAG") + self.assertIn("facility", response.data) + + def test_retrieve_facility_flag_as_normal_user(self): + """Test that normal user cannot retrieve facility flag (queryset filtered)""" + flag = self.create_facility_flag( + facility=self.facility, flag="TEST_FACILITY_FLAG" + ) + + self.client.force_authenticate(user=self.normal_user) + response = self.client.get(self.get_detail_url(flag.external_id)) + # get_queryset blocks access, returns 403 + self.assertEqual(response.status_code, 403) + + def test_retrieve_non_existent_facility_flag(self): + """Test that retrieving non-existent facility flag returns 404""" + self.client.force_authenticate(user=self.superuser) + response = self.client.get( + self.get_detail_url("00000000-0000-0000-0000-000000000000") + ) + self.assertEqual(response.status_code, 404) + + def test_retrieve_deleted_facility_flag(self): + """Test that retrieving deleted facility flag returns 404""" + flag = self.create_facility_flag( + facility=self.facility, flag="TEST_FACILITY_FLAG" + ) + flag.deleted = True + flag.save() + + self.client.force_authenticate(user=self.superuser) + response = self.client.get(self.get_detail_url(flag.external_id)) + self.assertEqual(response.status_code, 404) + + # ========== Update Tests ========== + + def test_update_facility_flag_as_superuser(self): + """Test that superuser can update facility flag""" + flag = self.create_facility_flag( + facility=self.facility, flag="TEST_FACILITY_FLAG" + ) + + self.client.force_authenticate(user=self.superuser) + response = self.client.put( + self.get_detail_url(flag.external_id), + {"flag": "TEST_FACILITY_FLAG_2"}, + format="json", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.data["flag"], "TEST_FACILITY_FLAG_2") + + # Verify in database + flag.refresh_from_db() + self.assertEqual(flag.flag, "TEST_FACILITY_FLAG_2") + + def test_update_facility_flag_as_normal_user(self): + """Test that normal user cannot update facility flag""" + flag = self.create_facility_flag( + facility=self.facility, flag="TEST_FACILITY_FLAG" + ) + + self.client.force_authenticate(user=self.normal_user) + response = self.client.put( + self.get_detail_url(flag.external_id), + {"flag": "TEST_FACILITY_FLAG_2"}, + format="json", + ) + self.assertEqual(response.status_code, 403) + + def test_partial_update_facility_flag(self): + """Test that partial update (PATCH) works""" + flag = self.create_facility_flag( + facility=self.facility, flag="TEST_FACILITY_FLAG" + ) + + self.client.force_authenticate(user=self.superuser) + response = self.client.patch( + self.get_detail_url(flag.external_id), + {"flag": "ENABLE_FEATURE_X"}, + format="json", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.data["flag"], "ENABLE_FEATURE_X") + + # ========== Delete Tests ========== + + def test_delete_facility_flag_as_superuser(self): + """Test that superuser can delete facility flag""" + flag = self.create_facility_flag( + facility=self.facility, flag="TEST_FACILITY_FLAG" + ) + + self.client.force_authenticate(user=self.superuser) + response = self.client.delete(self.get_detail_url(flag.external_id)) + self.assertEqual(response.status_code, 204) + + # Verify soft delete + flag.refresh_from_db() + self.assertTrue(flag.deleted) + + def test_delete_facility_flag_as_normal_user(self): + """Test that normal user cannot delete facility flag""" + flag = self.create_facility_flag( + facility=self.facility, flag="TEST_FACILITY_FLAG" + ) + + self.client.force_authenticate(user=self.normal_user) + response = self.client.delete(self.get_detail_url(flag.external_id)) + self.assertEqual(response.status_code, 403) + + def test_delete_already_deleted_facility_flag(self): + """Test that deleting already deleted flag returns 404""" + flag = self.create_facility_flag( + facility=self.facility, flag="TEST_FACILITY_FLAG" + ) + flag.deleted = True + flag.save() + + self.client.force_authenticate(user=self.superuser) + response = self.client.delete(self.get_detail_url(flag.external_id)) + self.assertEqual(response.status_code, 404) + + def test_delete_unregisters_flag(self): + """Test that deleting flag unregisters it from registry""" + # Create a unique flag + unique_flag = "UNIQUE_FACILITY_FLAG_DELETE" + FlagRegistry.register(FlagType.FACILITY, unique_flag) + flag = self.create_facility_flag(facility=self.facility, flag=unique_flag) + + self.client.force_authenticate(user=self.superuser) + response = self.client.delete(self.get_detail_url(flag.external_id)) + self.assertEqual(response.status_code, 204) + + # Verify flag is unregistered + flags = FlagRegistry.get_all_flags(FlagType.FACILITY) + self.assertNotIn(unique_flag, flags) + + # ========== Available Flags Tests ========== + + def test_available_flags_as_superuser(self): + """Test that superuser can view available flags""" + self.client.force_authenticate(user=self.superuser) + response = self.client.get(f"{self.base_url}available-flags/") + self.assertEqual(response.status_code, 200) + self.assertIn("available_flags", response.data) + self.assertIn("TEST_FACILITY_FLAG", response.data["available_flags"]) + self.assertIn("TEST_FACILITY_FLAG_2", response.data["available_flags"]) + + def test_available_flags_as_normal_user(self): + """Test that normal user cannot view available flags""" + self.client.force_authenticate(user=self.normal_user) + response = self.client.get(f"{self.base_url}available-flags/") + self.assertEqual(response.status_code, 403) + + # ========== Helper Methods ========== + + def create_facility_flag(self, facility, flag): + # Register flag (idempotent operation, safe to call multiple times) + FlagRegistry.register(FlagType.FACILITY, flag) + return FacilityFlag.objects.create(facility=facility, flag=flag) + + # ========== Multi-Facility Flag Deletion Tests ========== + + def test_delete_flag_does_not_unregister_when_other_facilities_have_it(self): + """Test that deleting flag from one facility doesn't unregister if other facilities have it""" + # Create same flag for two facilities + flag1 = self.create_facility_flag( + facility=self.facility, flag="SHARED_FACILITY_FLAG" + ) + flag2 = self.create_facility_flag( + facility=self.facility_2, flag="SHARED_FACILITY_FLAG" + ) + + self.client.force_authenticate(user=self.superuser) + + # Delete flag from first facility + response = self.client.delete(self.get_detail_url(flag1.external_id)) + self.assertEqual(response.status_code, 204) + + # Verify flag is still registered (because facility_2 still has it) + flags = FlagRegistry.get_all_flags(FlagType.FACILITY) + self.assertIn("SHARED_FACILITY_FLAG", flags) + + # Verify second facility's flag still works + flag2.refresh_from_db() + self.assertEqual(flag2.flag, "SHARED_FACILITY_FLAG") + self.assertFalse(flag2.deleted) + + def test_delete_last_facility_with_flag_unregisters_it(self): + """Test that deleting the last facility with a flag unregisters it""" + # Create flag for only one facility + flag = self.create_facility_flag( + facility=self.facility, flag="UNIQUE_FACILITY_FLAG" + ) + + self.client.force_authenticate(user=self.superuser) + + # Delete the only instance + response = self.client.delete(self.get_detail_url(flag.external_id)) + self.assertEqual(response.status_code, 204) + + # Verify flag is unregistered (no other facilities have it) + flags = FlagRegistry.get_all_flags(FlagType.FACILITY) + self.assertNotIn("UNIQUE_FACILITY_FLAG", flags) + + def test_delete_multiple_facilities_with_same_flag_sequence(self): + """Test deleting flag from multiple facilities in sequence""" + # Create third facility for this test + facility_3 = self.create_facility(name="Test Facility 3", user=self.superuser) + + # Create same flag for three facilities + flag1 = self.create_facility_flag( + facility=self.facility, flag="MULTI_FACILITY_FLAG" + ) + flag2 = self.create_facility_flag( + facility=self.facility_2, flag="MULTI_FACILITY_FLAG" + ) + flag3 = self.create_facility_flag( + facility=facility_3, flag="MULTI_FACILITY_FLAG" + ) + + self.client.force_authenticate(user=self.superuser) + + # Delete first facility's flag + response = self.client.delete(self.get_detail_url(flag1.external_id)) + self.assertEqual(response.status_code, 204) + flags = FlagRegistry.get_all_flags(FlagType.FACILITY) + self.assertIn("MULTI_FACILITY_FLAG", flags) # Still registered + + # Delete second facility's flag + response = self.client.delete(self.get_detail_url(flag2.external_id)) + self.assertEqual(response.status_code, 204) + flags = FlagRegistry.get_all_flags(FlagType.FACILITY) + self.assertIn( + "MULTI_FACILITY_FLAG", flags + ) # Still registered (facility_3 has it) + + # Delete last facility's flag + response = self.client.delete(self.get_detail_url(flag3.external_id)) + self.assertEqual(response.status_code, 204) + flags = FlagRegistry.get_all_flags(FlagType.FACILITY) + self.assertNotIn("MULTI_FACILITY_FLAG", flags) # Now unregistered diff --git a/care/emr/tests/test_user_flag_api.py b/care/emr/tests/test_user_flag_api.py new file mode 100644 index 0000000000..4144fd227d --- /dev/null +++ b/care/emr/tests/test_user_flag_api.py @@ -0,0 +1,367 @@ +from django.urls import reverse + +from care.users.models import UserFlag +from care.utils.registries.feature_flag import FlagRegistry, FlagType +from care.utils.tests.base import CareAPITestBase + + +class UserFlagAPITestCase(CareAPITestBase): + def setUp(self): + super().setUp() + # Register test flags + FlagRegistry.register(FlagType.USER, "TEST_FLAG") + FlagRegistry.register(FlagType.USER, "TEST_FLAG_2") + FlagRegistry.register(FlagType.USER, "BETA_FEATURES") + + self.superuser = self.create_super_user(username="superuser") + self.normal_user = self.create_user(username="normaluser") + self.target_user = self.create_user(username="targetuser") + + self.base_url = reverse("user-flags-list") + + def get_detail_url(self, external_id): + return reverse("user-flags-detail", kwargs={"external_id": external_id}) + + # ========== List Tests ========== + + def test_list_user_flags_as_superuser(self): + """Test that superuser can list all user flags""" + self.create_user_flag(user=self.target_user, flag="TEST_FLAG") + self.create_user_flag(user=self.normal_user, flag="TEST_FLAG_2") + + self.client.force_authenticate(user=self.superuser) + response = self.client.get(self.base_url) + self.assertEqual(response.status_code, 200) + data = response.json() + self.assertEqual(data["count"], 2) + + def test_list_user_flags_as_normal_user(self): + """Test that normal user cannot list user flags""" + self.create_user_flag(user=self.target_user, flag="TEST_FLAG") + + self.client.force_authenticate(user=self.normal_user) + response = self.client.get(self.base_url) + self.assertEqual(response.status_code, 403) + + def test_list_user_flags_unauthenticated(self): + """Test that unauthenticated user cannot list user flags""" + response = self.client.get(self.base_url) + # get_queryset authorization check happens before authentication check + self.assertEqual(response.status_code, 403) + + def test_filter_user_flags_by_user(self): + """Test filtering user flags by user""" + self.create_user_flag(user=self.target_user, flag="TEST_FLAG") + self.create_user_flag(user=self.target_user, flag="TEST_FLAG_2") + self.create_user_flag(user=self.normal_user, flag="BETA_FEATURES") + + self.client.force_authenticate(user=self.superuser) + response = self.client.get( + f"{self.base_url}?user={self.target_user.external_id}" + ) + self.assertEqual(response.status_code, 200) + data = response.json() + self.assertEqual(data["count"], 2) + + def test_filter_user_flags_by_flag_name(self): + """Test filtering user flags by flag name (case-insensitive)""" + self.create_user_flag(user=self.target_user, flag="TEST_FLAG") + self.create_user_flag(user=self.normal_user, flag="TEST_FLAG_2") + + self.client.force_authenticate(user=self.superuser) + + # Test exact case + response = self.client.get(f"{self.base_url}?flag=TEST_FLAG") + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json()["count"], 1) + + # Test case-insensitive + response = self.client.get(f"{self.base_url}?flag=test_flag") + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json()["count"], 1) + + # ========== Create Tests ========== + + def test_create_user_flag_as_superuser(self): + """Test that superuser can create user flags""" + self.client.force_authenticate(user=self.superuser) + response = self.client.post( + self.base_url, + {"flag": "TEST_FLAG", "user": self.target_user.external_id}, + format="json", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.data["flag"], "TEST_FLAG") + + # Verify flag was created and registered + self.assertTrue( + UserFlag.objects.filter( + user=self.target_user, flag="TEST_FLAG", deleted=False + ).exists() + ) + + def test_create_user_flag_as_normal_user(self): + """Test that normal user cannot create user flags""" + self.client.force_authenticate(user=self.normal_user) + response = self.client.post( + self.base_url, + {"flag": "TEST_FLAG", "user": self.target_user.external_id}, + format="json", + ) + self.assertEqual(response.status_code, 403) + + def test_create_user_flag_with_invalid_user_uuid(self): + """Test that creating user flag with invalid user UUID fails""" + self.client.force_authenticate(user=self.superuser) + response = self.client.post( + self.base_url, + {"flag": "TEST_FLAG", "user": "00000000-0000-0000-0000-000000000000"}, + format="json", + ) + # Invalid UUID causes validation error (400), not 404 + self.assertEqual(response.status_code, 400) + + def test_create_user_flag_missing_required_fields(self): + """Test that creating user flag without required fields fails""" + self.client.force_authenticate(user=self.superuser) + + # Missing flag + response = self.client.post( + self.base_url, + {"user": self.target_user.external_id}, + format="json", + ) + self.assertEqual(response.status_code, 400) + + # Missing user + response = self.client.post( + self.base_url, + {"flag": "TEST_FLAG"}, + format="json", + ) + self.assertEqual(response.status_code, 400) + + # ========== Retrieve Tests ========== + + def test_retrieve_user_flag_as_superuser(self): + """Test that superuser can retrieve user flag""" + flag = self.create_user_flag(user=self.target_user, flag="TEST_FLAG") + + self.client.force_authenticate(user=self.superuser) + response = self.client.get(self.get_detail_url(flag.external_id)) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.data["flag"], "TEST_FLAG") + self.assertIn("user", response.data) + + def test_retrieve_user_flag_as_normal_user(self): + """Test that normal user cannot retrieve user flag (queryset filtered)""" + flag = self.create_user_flag(user=self.target_user, flag="TEST_FLAG") + + self.client.force_authenticate(user=self.normal_user) + response = self.client.get(self.get_detail_url(flag.external_id)) + # get_queryset blocks access, returns 403 + self.assertEqual(response.status_code, 403) + + def test_retrieve_non_existent_user_flag(self): + """Test that retrieving non-existent user flag returns 404""" + self.client.force_authenticate(user=self.superuser) + response = self.client.get( + self.get_detail_url("00000000-0000-0000-0000-000000000000") + ) + self.assertEqual(response.status_code, 404) + + def test_retrieve_deleted_user_flag(self): + """Test that retrieving deleted user flag returns 404""" + flag = self.create_user_flag(user=self.target_user, flag="TEST_FLAG") + flag.deleted = True + flag.save() + + self.client.force_authenticate(user=self.superuser) + response = self.client.get(self.get_detail_url(flag.external_id)) + self.assertEqual(response.status_code, 404) + + # ========== Update Tests ========== + + def test_update_user_flag_as_superuser(self): + """Test that superuser can update user flag""" + flag = self.create_user_flag(user=self.target_user, flag="TEST_FLAG") + + self.client.force_authenticate(user=self.superuser) + response = self.client.put( + self.get_detail_url(flag.external_id), + {"flag": "TEST_FLAG_2"}, + format="json", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.data["flag"], "TEST_FLAG_2") + + # Verify in database + flag.refresh_from_db() + self.assertEqual(flag.flag, "TEST_FLAG_2") + + def test_update_user_flag_as_normal_user(self): + """Test that normal user cannot update user flag""" + flag = self.create_user_flag(user=self.target_user, flag="TEST_FLAG") + + self.client.force_authenticate(user=self.normal_user) + response = self.client.put( + self.get_detail_url(flag.external_id), + {"flag": "TEST_FLAG_2"}, + format="json", + ) + self.assertEqual(response.status_code, 403) + + def test_partial_update_user_flag(self): + """Test that partial update (PATCH) works""" + flag = self.create_user_flag(user=self.target_user, flag="TEST_FLAG") + + self.client.force_authenticate(user=self.superuser) + response = self.client.patch( + self.get_detail_url(flag.external_id), + {"flag": "BETA_FEATURES"}, + format="json", + ) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.data["flag"], "BETA_FEATURES") + + # ========== Delete Tests ========== + + def test_delete_user_flag_as_superuser(self): + """Test that superuser can delete user flag""" + flag = self.create_user_flag(user=self.target_user, flag="TEST_FLAG") + + self.client.force_authenticate(user=self.superuser) + response = self.client.delete(self.get_detail_url(flag.external_id)) + self.assertEqual(response.status_code, 204) + + # Verify soft delete + flag.refresh_from_db() + self.assertTrue(flag.deleted) + + def test_delete_user_flag_as_normal_user(self): + """Test that normal user cannot delete user flag""" + flag = self.create_user_flag(user=self.target_user, flag="TEST_FLAG") + + self.client.force_authenticate(user=self.normal_user) + response = self.client.delete(self.get_detail_url(flag.external_id)) + self.assertEqual(response.status_code, 403) + + def test_delete_already_deleted_user_flag(self): + """Test that deleting already deleted flag returns 404""" + flag = self.create_user_flag(user=self.target_user, flag="TEST_FLAG") + flag.deleted = True + flag.save() + + self.client.force_authenticate(user=self.superuser) + response = self.client.delete(self.get_detail_url(flag.external_id)) + self.assertEqual(response.status_code, 404) + + def test_delete_unregisters_flag(self): + """Test that deleting flag unregisters it from registry""" + # Create a unique flag + unique_flag = "UNIQUE_TEST_FLAG_DELETE" + FlagRegistry.register(FlagType.USER, unique_flag) + flag = self.create_user_flag(user=self.target_user, flag=unique_flag) + + self.client.force_authenticate(user=self.superuser) + response = self.client.delete(self.get_detail_url(flag.external_id)) + self.assertEqual(response.status_code, 204) + + # Verify flag is unregistered + flags = FlagRegistry.get_all_flags(FlagType.USER) + self.assertNotIn(unique_flag, flags) + + # ========== Available Flags Tests ========== + + def test_available_flags_as_superuser(self): + """Test that superuser can view available flags""" + self.client.force_authenticate(user=self.superuser) + response = self.client.get(f"{self.base_url}available-flags/") + self.assertEqual(response.status_code, 200) + self.assertIn("available_flags", response.data) + self.assertIn("TEST_FLAG", response.data["available_flags"]) + self.assertIn("TEST_FLAG_2", response.data["available_flags"]) + + def test_available_flags_as_normal_user(self): + """Test that normal user cannot view available flags""" + self.client.force_authenticate(user=self.normal_user) + response = self.client.get(f"{self.base_url}available-flags/") + self.assertEqual(response.status_code, 403) + + # ========== Helper Methods ========== + + def create_user_flag(self, user, flag): + # Register flag (idempotent operation, safe to call multiple times) + FlagRegistry.register(FlagType.USER, flag) + return UserFlag.objects.create(user=user, flag=flag) + + # ========== Multi-User Flag Deletion Tests ========== + + def test_delete_flag_does_not_unregister_when_other_users_have_it(self): + """Test that deleting flag from one user doesn't unregister if other users have it""" + # Create same flag for two users + flag1 = self.create_user_flag(user=self.target_user, flag="SHARED_FLAG") + flag2 = self.create_user_flag(user=self.normal_user, flag="SHARED_FLAG") + + self.client.force_authenticate(user=self.superuser) + + # Delete flag from first user + response = self.client.delete(self.get_detail_url(flag1.external_id)) + self.assertEqual(response.status_code, 204) + + # Verify flag is still registered (because normal_user still has it) + flags = FlagRegistry.get_all_flags(FlagType.USER) + self.assertIn("SHARED_FLAG", flags) + + # Verify second user's flag still works + flag2.refresh_from_db() + self.assertEqual(flag2.flag, "SHARED_FLAG") + self.assertFalse(flag2.deleted) + + def test_delete_last_user_with_flag_unregisters_it(self): + """Test that deleting the last user with a flag unregisters it""" + # Create flag for only one user + flag = self.create_user_flag(user=self.target_user, flag="UNIQUE_USER_FLAG") + + self.client.force_authenticate(user=self.superuser) + + # Delete the only instance + response = self.client.delete(self.get_detail_url(flag.external_id)) + self.assertEqual(response.status_code, 204) + + # Verify flag is unregistered (no other users have it) + flags = FlagRegistry.get_all_flags(FlagType.USER) + self.assertNotIn("UNIQUE_USER_FLAG", flags) + + def test_delete_multiple_users_with_same_flag_sequence(self): + """Test deleting flag from multiple users in sequence""" + # Create same flag for three users + superuser_flag = self.create_user_flag( + user=self.superuser, flag="MULTI_USER_FLAG" + ) + target_flag = self.create_user_flag( + user=self.target_user, flag="MULTI_USER_FLAG" + ) + normal_flag = self.create_user_flag( + user=self.normal_user, flag="MULTI_USER_FLAG" + ) + + self.client.force_authenticate(user=self.superuser) + + # Delete first user's flag + response = self.client.delete(self.get_detail_url(target_flag.external_id)) + self.assertEqual(response.status_code, 204) + flags = FlagRegistry.get_all_flags(FlagType.USER) + self.assertIn("MULTI_USER_FLAG", flags) # Still registered + + # Delete second user's flag + response = self.client.delete(self.get_detail_url(normal_flag.external_id)) + self.assertEqual(response.status_code, 204) + flags = FlagRegistry.get_all_flags(FlagType.USER) + self.assertIn("MULTI_USER_FLAG", flags) # Still registered (superuser has it) + + # Delete last user's flag + response = self.client.delete(self.get_detail_url(superuser_flag.external_id)) + self.assertEqual(response.status_code, 204) + flags = FlagRegistry.get_all_flags(FlagType.USER) + self.assertNotIn("MULTI_USER_FLAG", flags) # Now unregistered diff --git a/care/security/authorization/__init__.py b/care/security/authorization/__init__.py index 283f19746b..545d8d95fe 100644 --- a/care/security/authorization/__init__.py +++ b/care/security/authorization/__init__.py @@ -7,6 +7,7 @@ from .device import * # noqa from .encounter import * # noqa from .facility import * # noqa +from .facility_flag import * # noqa from .facility_location import * # noqa from .facilityorganization import * # noqa from .healthcare_service import * # noqa @@ -33,3 +34,4 @@ from .template import * # noqa from .token import * # noqa from .user import * # noqa +from .user_flag import * # noqa diff --git a/care/security/authorization/facility_flag.py b/care/security/authorization/facility_flag.py new file mode 100644 index 0000000000..e695d46c55 --- /dev/null +++ b/care/security/authorization/facility_flag.py @@ -0,0 +1,21 @@ +from care.security.authorization import AuthorizationController +from care.security.authorization.base import AuthorizationHandler + + +class FacilityFlagAccess(AuthorizationHandler): + def can_read_facility_flag(self, user): + """ + Check if the user has permission to read facility flags + Only superusers can read facility flags + """ + return user.is_superuser + + def can_write_facility_flag(self, user): + """ + Check if the user has permission to write facility flags + Only superusers can write facility flags + """ + return user.is_superuser + + +AuthorizationController.register_internal_controller(FacilityFlagAccess) diff --git a/care/security/authorization/user_flag.py b/care/security/authorization/user_flag.py new file mode 100644 index 0000000000..d11a9a3ed5 --- /dev/null +++ b/care/security/authorization/user_flag.py @@ -0,0 +1,21 @@ +from care.security.authorization import AuthorizationController +from care.security.authorization.base import AuthorizationHandler + + +class UserFlagAccess(AuthorizationHandler): + def can_read_user_flag(self, user): + """ + Check if the user has permission to read user flags + Only superusers can read user flags + """ + return user.is_superuser + + def can_write_user_flag(self, user): + """ + Check if the user has permission to write user flags + Only superusers can write user flags + """ + return user.is_superuser + + +AuthorizationController.register_internal_controller(UserFlagAccess) diff --git a/config/api_router.py b/config/api_router.py index fc939d527a..462198b09f 100644 --- a/config/api_router.py +++ b/config/api_router.py @@ -29,6 +29,7 @@ FacilityUsersViewSet, FacilityViewSet, ) +from care.emr.api.viewsets.facility_flag import FacilityFlagViewSet from care.emr.api.viewsets.facility_organization import ( FacilityOrganizationUsersViewSet, FacilityOrganizationViewSet, @@ -106,6 +107,7 @@ from care.emr.api.viewsets.tag_config import TagConfigViewSet from care.emr.api.viewsets.totp import TOTPViewSet from care.emr.api.viewsets.user import UserViewSet +from care.emr.api.viewsets.user_flag import UserFlagViewSet from care.emr.api.viewsets.valueset import ValueSetViewSet from care.security.api.viewsets.permissions import PermissionViewSet from care.security.api.viewsets.roles import RoleViewSet @@ -115,6 +117,10 @@ router.register("users", UserViewSet, basename="users") +router.register("user_flags", UserFlagViewSet, basename="user-flags") + +router.register("facility_flags", FacilityFlagViewSet, basename="facility-flags") + router.register("plug_config", PlugConfigViewset, basename="plug_configs") user_nested_router = NestedSimpleRouter(router, r"users", lookup="users")