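# 您最多选择25个主题 主题必须以中文或者字母或数字开头,可以包含连字符 (-),并且长度不得超过35个字符
# (Web-page residue, not part of the program. Translation: "You can select up to
#  25 topics. Topics must start with a Chinese character, letter, or digit, may
#  contain hyphens (-), and must be no longer than 35 characters.")
#
# 371 行 (371 lines)
# 12 KiB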

import argparse
import os
import pathlib
from typing import Dict, List, Tuple
import numpy as np
import streamlit as st
import PIL
from PIL import ImageFont
from PIL.Image import Image
from PIL.ImageDraw import ImageDraw
from pyrception_utils import PyrceptionDataset
# Configure the page once, at import time, before any widgets are created.
st.set_page_config(layout="wide")

# --------------------------------Custom components----------------------------------------------------------------------
import streamlit.components.v1 as components

# Directory of this script; component build folders are resolved relative to it
# so the app does not depend on the current working directory.
root_dir = os.path.dirname(os.path.abspath(__file__))

# Each custom component is a prebuilt frontend (its `build` folder) registered
# with Streamlit under its own component name.
build_dir_slider = os.path.join(root_dir, "custom_components/slider/build")
_discrete_slider = components.declare_component("discrete_slider", path=build_dir_slider)

build_dir_page_selector = os.path.join(root_dir, "custom_components/pageselector/build")
_page_selector = components.declare_component("page_selector", path=build_dir_page_selector)

build_dir_go_to = os.path.join(root_dir, "custom_components/goto/build")
_go_to = components.declare_component("go_to", path=build_dir_go_to)

build_dir_item_selector = os.path.join(root_dir, "custom_components/itemselector/build")
_item_selector = components.declare_component("item_selector", path=build_dir_item_selector)
def discrete_slider(greeting, name, key, default=0):
    """
    Renders the custom discrete slider component and returns its value.

    :param greeting: forwarded to the component as ``greeting``
    :param name: forwarded to the component as ``name``
    :param key: Streamlit widget key for this component instance
    :param default: component return value before the frontend sends one
    :return: the component's current value
    """
    component_kwargs = {
        "greeting": greeting,
        "name": name,
        "default": default,
        "key": key,
    }
    return _discrete_slider(**component_kwargs)
def page_selector(startAt, incrementAmt, key=0):
    """
    Renders the custom page selector component and returns its value.

    :param startAt: forwarded to the component as ``startAt``
    :param incrementAmt: forwarded to the component as ``incrementAmt``
    :param key: Streamlit widget key for this component instance
    :return: the component's value; 0 until the frontend sends one
    """
    selected = _page_selector(
        startAt=startAt,
        incrementAmt=incrementAmt,
        key=key,
        default=0,
    )
    return selected
def go_to(key=0):
    """
    Renders the custom "go to" component and returns its value.

    :param key: Streamlit widget key for this component instance
    :return: the component's value; 0 until the frontend sends one
    """
    target = _go_to(key=key, default=0)
    return target
def item_selector(startAt, incrementAmt, datasetSize, key=0):
    """
    Renders the custom item selector component and returns its value.

    :param startAt: forwarded to the component as ``startAt``
    :param incrementAmt: forwarded to the component as ``incrementAmt``
    :param datasetSize: forwarded to the component as ``datasetSize``
    :param key: Streamlit widget key for this component instance
    :return: the component's value; 0 until the frontend sends one
    """
    component_kwargs = dict(
        startAt=startAt,
        incrementAmt=incrementAmt,
        datasetSize=datasetSize,
        key=key,
        default=0,
    )
    return _item_selector(**component_kwargs)
#-------------------------------------END-------------------------------------------------------------------------------
def list_datasets(path) -> List:
    """
    Lists the dataset folders in a directory.

    Every immediate subdirectory of ``path`` counts as a dataset, except one
    named "Unity". Names are sorted because ``os.listdir`` makes no ordering
    guarantee; sorting keeps the list order, and therefore the first (default)
    entry of any selectbox built from it, stable across runs and platforms.

    :param path: path to a directory that contains dataset folders
    :type str:
    :return: A sorted list of dataset directory names (names only, not full paths).
    :rtype: List
    :raises OSError: if ``path`` cannot be listed (e.g. it does not exist)
    """
    return sorted(
        entry
        for entry in os.listdir(path)
        if entry != "Unity" and os.path.isdir(os.path.join(path, entry))
    )
def frame_selector_ui(dataset: PyrceptionDataset) -> int:
    """
    Renders a sidebar slider for choosing which dataset frame to display.

    :param dataset: the PyrceptionDataset
    :type PyrceptionDataset:
    :return: The selected image index, from 0 to len(dataset) - 1
    :rtype: int
    """
    sidebar = st.sidebar
    sidebar.markdown("# Image set")
    last_index = len(dataset) - 1
    return sidebar.slider("Image number", 0, last_index)
def draw_image_with_boxes(
    image: Image,
    classes: Dict,
    labels: List,
    boxes: List[List],
    colors: Dict,
    header: str,
    description: str,
) -> Image:
    """
    Returns a copy of an image with labeled bounding boxes drawn on it.

    Nothing is rendered to Streamlit here; callers display the returned image.
    The input image is not modified.

    :param image: the PIL image
    :type PIL:
    :param classes: class names, looked up with ``label - 1``
    :type Dict:
    :param labels: list of integer object labels for the frame
    :type List:
    :param boxes: List of bounding boxes (as a List of coordinates) for the frame
    :type List[List]:
    :param colors: class colors, keyed by class name
    :type Dict:
    :param header: Image header (currently unused)
    :type str:
    :param description: Image description (currently unused)
    :type str:
    :return: the annotated copy of ``image``
    :rtype: PIL
    """
    annotated = image.copy()
    draw = ImageDraw(annotated)
    font_path = pathlib.Path(__file__).parent.absolute() / "NairiNormal-m509.ttf"
    font = ImageFont.truetype(str(font_path), 15)
    for label, box in zip(labels, boxes):
        # Labels are offset by one relative to the indices of `classes`.
        class_name = classes[label - 1]
        color = colors[class_name]
        draw.rectangle(box, outline=color, width=2)
        # The class name is written at the box's first coordinate pair.
        draw.text((box[0], box[1]), class_name, font=font, fill=color)
    return annotated
def draw_image_with_semantic_segmentation(
    image: Image,
    height: int,
    width: int,
    segmentation: Image,
    header: str,
    description: str,
    opacity: float = 0.6,
) -> Image:
    """
    Returns a copy of an image with a semantic segmentation mask overlaid.

    Opaque pure-black segmentation pixels are treated as background and made
    fully transparent. Every pixel with a non-black color is pasted over the
    image at the given opacity. The input image is not modified.

    :param image: the PIL image
    :type PIL:
    :param height: height of the image (currently unused)
    :type int:
    :param width: width of the image (currently unused)
    :type int:
    :param segmentation: Segmentation Image, pasted at the top-left corner
    :type PIL:
    :param header: Image header (currently unused)
    :type str:
    :param description: Image description (currently unused)
    :type str:
    :param opacity: opacity of the colored segmentation pixels, in [0, 1]
    :type float:
    :return: the image with the segmentation overlay
    :rtype: PIL
    :raises ValueError: if ``opacity`` is outside [0, 1]
    """
    if not 0.0 <= opacity <= 1.0:
        raise ValueError(f"opacity must be in [0, 1], got {opacity}")
    rgba = np.array(segmentation.convert("RGBA"))
    is_colored = rgba[..., :3].any(axis=-1)
    # Opaque pure-black pixels are background: clear all four channels.
    background = ~is_colored & (rgba[..., 3] == 255)
    rgba[background] = 0
    rgba[is_colored, 3] = int(opacity * 255)
    foreground = PIL.Image.fromarray(rgba)
    overlaid = image.copy()
    # Using the overlay as its own mask blends it by its alpha channel.
    overlaid.paste(foreground, (0, 0), foreground)
    return overlaid
def draw_image_stacked(
    image: Image,
    classes: Dict,
    labels: List,
    boxes: List[List],
    colors: Dict,
    header: str,
    description: str,
    height: int,
    width: int,
    segmentation: Image,
):
    """
    Draws an image in Streamlit with the segmentation blended in and labeled
    bounding boxes on top.

    A sidebar slider ("color intensity 2 (%)", default 65) sets how strongly
    segmentation colors are mixed into the image. Blending happens only inside
    the top-left ``height`` x ``width`` region (clipped to both images), and
    only where the segmentation pixel is not pure black.

    :param image: the PIL image; converted to RGB for blending
    :type PIL:
    :param classes: class names, looked up with ``label - 1``
    :type Dict:
    :param labels: list of integer object labels for the frame
    :type List:
    :param boxes: List of bounding boxes (as a List of coordinates) for the frame
    :type List[List]:
    :param colors: class colors, keyed by class name
    :type Dict:
    :param header: subheader shown above the image
    :type str:
    :param description: markdown shown under the header
    :type str:
    :param height: number of pixel rows to blend
    :type int:
    :param width: number of pixel columns to blend
    :type int:
    :param segmentation: Segmentation Image; converted to RGB for blending
    :type PIL:
    """
    color_intensity = st.sidebar.slider('color intensity 2 (%)', 0, 100, 65)
    alpha = color_intensity / 100

    # Vectorized blend over every row and column of the region, including the last ones.
    pixels = np.array(image.convert("RGB"))
    seg_pixels = np.array(segmentation.convert("RGB"))
    rows = max(0, min(height, pixels.shape[0], seg_pixels.shape[0]))
    cols = max(0, min(width, pixels.shape[1], seg_pixels.shape[1]))
    region = pixels[:rows, :cols]  # a view: assignments write into `pixels`
    seg_region = seg_pixels[:rows, :cols]
    highlighted = seg_region.any(axis=-1)
    blended = (1 - alpha) * region[highlighted] + alpha * seg_region[highlighted]
    # astype truncates toward zero, like int().
    region[highlighted] = blended.astype(np.uint8)
    image = PIL.Image.fromarray(pixels)

    image_draw = ImageDraw(image)
    font_path = pathlib.Path(__file__).parent.absolute() / "NairiNormal-m509.ttf"
    font = ImageFont.truetype(str(font_path), 15)
    for label, box in zip(labels, boxes):
        class_name = classes[label - 1]
        color = colors[class_name]
        image_draw.rectangle(box, outline=color, width=2)
        image_draw.text((box[0], box[1]), class_name, font=font, fill=color)
    st.subheader(header)
    st.markdown(description)
    st.image(image, use_column_width=True)
def display_count(header: str, description: str):
    """
    Stub: accepts a header and description but renders nothing yet.

    :param header: Image header
    :type str:
    :param description: Image description
    :type str:
    :return: always None
    """
    return None
@st.cache(show_spinner=True, allow_output_mutation=True)
def load_perception_dataset(path: str) -> Tuple:
    """
    Loads a perception dataset and assigns each class a random color.

    Both the dataset and its color scheme are cached by Streamlit, so the colors
    stay the same across reruns for the same path.

    :param path: Dataset path
    :type str:
    :return: (colors, dataset), where colors maps each class name to an RGB tuple
    :rtype: Tuple
    """
    dataset = PyrceptionDataset(data_dir=path)
    colors = {}
    for class_name in dataset.classes:
        # Channels are drawn from [128, 255), which keeps box colors fairly bright.
        colors[class_name] = tuple(np.random.randint(128, 255, size=3))
    return colors, dataset
def preview_dataset(base_dataset_dir: str, num_rows: int = 5):
    """
    Adds streamlit components to the app to construct the dataset preview.

    A sidebar selectbox lists the dataset folders in ``base_dataset_dir``. The
    selected dataset is loaded (cached) and shown as an image grid. If the
    directory cannot be read or contains no datasets, a message is shown
    instead of a traceback or an empty page.

    :param base_dataset_dir: The directory that contains the perception datasets.
    :type str:
    :param num_rows: Number of grid rows shown per page.
    :type int:
    """
    try:
        dataset_names = list_datasets(base_dataset_dir)
    except OSError as err:
        st.error(f"Cannot read dataset directory {base_dataset_dir!r}: {err}")
        return
    dataset_name = st.sidebar.selectbox("Please select a dataset...", dataset_names)
    if dataset_name is None:
        # The selectbox yields None when it has no options.
        st.warning(f"No dataset folders found in {base_dataset_dir!r}.")
        return
    colors, dataset = load_perception_dataset(
        os.path.join(base_dataset_dir, dataset_name)
    )
    grid_view(num_rows, colors, dataset)
def sidebar():
    """Stub for custom sidebar content; does nothing and returns None."""
    pass
def navbar():
    """Stub for a navigation bar; does nothing and returns None."""
    pass
def grid_view(num_rows, colors, dataset):
    """
    Shows one page of the dataset as an image grid with optional overlays.

    The header row holds the item selector (left) and an "images per row"
    slider (right). Sidebar checkboxes toggle the semantic segmentation and 2D
    bounding box overlays. A page shows up to ``num_rows * images_per_row``
    images, starting at the index returned by the item selector.

    :param num_rows: number of grid rows per page
    :type int:
    :param colors: class colors used for the bounding boxes
    :type Dict:
    :param dataset: the PyrceptionDataset to display
    :type PyrceptionDataset:
    """
    header = st.beta_columns([2/3, 1/3])
    num_cols = header[1].slider(label="Image per row: ", min_value=1, max_value=5, step=1, value=3)
    page_size = num_cols * num_rows
    with header[0]:
        start_at = item_selector(0, page_size, len(dataset))
    cols = st.beta_columns(num_cols)
    semantic_segmentation = st.sidebar.checkbox("Semantic Segmentation", key="ss")
    bounding_boxes_2d = st.sidebar.checkbox("Bounding Boxes", key="bb2d")

    classes = dataset.classes  # loop-invariant; read once per page
    end = min(start_at + page_size, len(dataset))
    for i in range(start_at, end):
        image, segmentation, target = dataset[i]
        if semantic_segmentation:
            image = draw_image_with_semantic_segmentation(
                image, dataset.metadata.image_size[0], dataset.metadata.image_size[1],
                segmentation, "Semantic Segmentation Preview", ""
            )
        if bounding_boxes_2d:
            image = draw_image_with_boxes(
                image, classes, target["labels"], target["boxes"], colors,
                "Bounding Boxes Preview", ""
            )
        # Fill columns left to right so the page's first image lands in column 0.
        container = cols[(i - start_at) % num_cols].beta_container()
        container.image(image, caption=str(i), use_column_width=True)
        if container.button(label="Expand image", key="exp" + str(i)):
            container.write("IMAGE WAS CLICKED")
def zoom(index):
    """Stub for zooming into the image at ``index``; currently ignores it and returns None."""
    pass
def preview_app(args):
    """
    Starts the dataset preview app.

    :param args: parsed arguments; ``args.data`` is the directory holding the dataset folders
    :type args: Namespace
    :raises ValueError: if ``args.data`` is None
    """
    dataset_dir = args.data
    if dataset_dir is None:
        raise ValueError("Please specify the path to the main dataset directory!")
    st.sidebar.title("Pyrception Dataset Preview")
    preview_dataset(dataset_dir)
if __name__ == "__main__":
    # Single positional argument `data`: the directory holding the dataset folders.
    cli = argparse.ArgumentParser()
    cli.add_argument("data", type=str)
    args = cli.parse_args()
    # NOTE(review): hides buttons via a generated CSS class name (css-9eqr5v);
    # such names are likely tied to a specific Streamlit build — confirm it still matches.
    hide_button_css = '<style>button.css-9eqr5v{display: none}</style>'
    st.markdown(hide_button_css, unsafe_allow_html=True)
    preview_app(args)