Data schema

Schema for lab data management.

This package provides DataJoint schemas (tables) for accessing and managing laboratory data. The schemas are organized into modules by data type:

  • general - Core tables for files, subjects, sessions, datasets
  • procedures - Tables for experimental procedures and protocols
  • ephys - Tables for electrophysiology recordings and analysis
  • twophoton - Tables for two-photon microscopy data
  • onephoton - Tables for one-photon imaging (widefield and miniscope)
  • tasks - Tables for behavioral task data
  • video - Tables for video recordings
  • histology - Tables for histology and anatomy data

File

Bases: Manual

Table for tracking files stored in S3 or local storage.

This table stores metadata about files including their path, storage location, creation date, size and MD5 checksum. It provides methods for:

  • Deleting files from both the database and S3 storage (does not delete local files)
  • Downloading files from S3 to local storage (does not download files from local storage)
  • Checking if files are archived in S3 Glacier storage (does not check local storage)

The table is used as a base class for AnalysisFile which handles analysis outputs.
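The metadata columns (path, creation date, size, MD5 checksum) can be reproduced with the standard library. The helper below is an illustrative sketch of what an entry looks like before insertion, not a function of the package:

```python
import hashlib
from datetime import datetime, timezone
from pathlib import Path

def file_metadata(path):
    """Compute the metadata columns stored by the File table for one file."""
    p = Path(path)
    md5 = hashlib.md5()
    with open(p, 'rb') as fd:
        for chunk in iter(lambda: fd.read(1 << 20), b''):  # read in 1 MiB chunks
            md5.update(chunk)
    return dict(file_path=str(p),
                file_datetime=datetime.fromtimestamp(p.stat().st_mtime, tz=timezone.utc),
                file_size=float(p.stat().st_size),  # stored as double in the table
                file_md5=md5.hexdigest())
```

Such a dict matches the table definition and could be passed to `File.insert1` (with a `storage` key added).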

Source code in labdata/schema/general.py
@globalschema 
class File(dj.Manual):
    '''Table for tracking files stored in S3 or local storage.

    This table stores metadata about files including their path, storage location,
    creation date, size and MD5 checksum. It provides methods for:

    - Deleting files from both the database and S3 storage (does not delete local files)
    - Downloading files from S3 to local storage (does not download files from local storage)
    - Checking if files are archived in S3 Glacier storage (does not check local storage)

    The table is used as a base class for AnalysisFile which handles analysis outputs.
    '''
    definition = '''
    file_path                 : varchar(300)  # Path to the file
    storage = "{0}"           : varchar(12)   # storage name 
    ---
    file_datetime             : datetime      # date created
    file_size                 : double        # using double because int64 does not exist
    file_md5 = NULL           : varchar(32)   # md5 checksum
    '''.format(DEFAULT_RAW_STORAGE)
    storage = DEFAULT_RAW_STORAGE
    # Files get deleted from AWS if the user has permissions
    def delete(
            self,
            transaction = True,
            safemode  = None,
            force_parts = False):
        '''Delete files from both the database and S3 storage.

        Parameters
        ----------
        transaction : bool, optional
            Whether to perform deletion as a transaction, by default True
        safemode : bool, optional
            Whether to run in safe mode, by default None
        force_parts : bool, optional
            Whether to force deletion of parts, by default False

        Raises
        ------
        ValueError
            If files are deleted from database but not from S3
        '''

        from ..s3 import s3_delete_file
        from tqdm import tqdm
        filesdict = [f for f in self]
        super().delete(transaction = transaction,
                       safemode = safemode,
                       force_parts = force_parts)
        if len(self) == 0:
            files_not_deleted = []
            files_kept = []
            for s in tqdm(filesdict, desc = 'Deleting objects from s3 storage:'):
                fname = s["file_path"]
                storage = prefs['storage'][s['storage']]
                if storage['protocol'] == 's3':
                    try:
                        s3_delete_file(fname,
                                   storage = prefs['storage'][s['storage']],
                                   remove_versions = True)
                    except Exception as err:
                        print(f'Could not delete {fname}.')
                        files_not_deleted.append(fname)
                else:
                    print(f'Skipping {fname} because it is not in S3.')
                    files_kept.append(fname)
            if len(files_not_deleted):
                print('\n'.join(files_not_deleted))
                raise(ValueError('''

[Integrity error] Files were deleted from the database but not from AWS.

            Save this message and show it to your database ADMIN.

{0}

'''.format('\n'.join(files_not_deleted))))
            if len(files_kept):
                print('Files were not deleted from local storage.')

    def check_if_files_local(self, local_paths = None):
        '''
        Checks if files are in a local path, searching across all local paths.

        Parameters
        ----------
        local_paths : list of str or Path, optional
            List of local paths to check for files, by default None uses paths in preferences

        Returns
        -------
        tuple
            Tuple of local file paths and missing files

        Raises
        ------
        ValueError
            If no files in the object
        '''
        if local_paths is None:
            local_paths = prefs['local_paths']
        if not len(self):
            raise(ValueError('No files to get.'))
        # this does not work with multiple storages
        files = [f['file_path'] for f in self]
        localfiles = [find_local_filepath(a, local_paths = local_paths) for a in files]
        # check if they exist and download only missing files.
        missingfiles = []
        for f in files:
            if not np.any([str(l).endswith(str(Path(f))) for l in localfiles]):
                missingfiles.append(f)
        return [l for l in localfiles if not l is None], missingfiles

    def get(self,local_paths = None, check_if_archived = True, restore=True, download = True,):
        '''Download files from S3 to local storage.

        Parameters
        ----------
        local_paths : list of str or Path, optional
            List of local paths to download files to, by default None uses paths in preferences
        check_if_archived : bool, optional
            Whether to check if files are in Glacier storage, by default True
        restore : bool, optional
            Whether to restore archived files, by default True
        download : bool, optional
            Whether to actually download the files, by default True

        Returns
        -------
        list
            List of local file paths that were downloaded

        Raises
        ------
        ValueError
            If no files are found to download
        '''
        if local_paths is None:
            local_paths = prefs['local_paths']
        if not len(self):
            raise(ValueError('No files to get.'))

        localfiles, remotefiles = self.check_if_files_local(local_paths = local_paths)
        storage = [f['storage'] for f in self][0]
        remotefiles = self & [dict(file_path = f) for f in remotefiles]
        if len(remotefiles):
            if prefs['storage'][storage]['protocol'] == 's3':
                if check_if_archived:
                    # TODO: add to the preference file to not restore by default.
                    self.check_if_files_archived(files = remotefiles, restore = restore)
                if download:
                    print(f'Downloading {len(remotefiles)} files from S3 [{storage}].')
                    remotefiles = [r['file_path'] for r in remotefiles]
                    dstfiles = [Path(local_paths[0])/f for f in remotefiles]  # place to store file.
                    from ..s3 import copy_from_s3
                    copy_from_s3(remotefiles,dstfiles,storage_name = storage)
                    localfiles, _ = self.check_if_files_local(local_paths = local_paths)
            elif prefs['storage'][storage]['protocol'] == 'local':
                # TODO, copy files from local storage to the first local path.
                print('Downloading from local storage is not implemented, use local_paths.')
        return localfiles

    def check_if_files_archived(self, files = None, restore = True, suppress_error = False):
        '''Check if files are archived in S3 Glacier storage.

        Parameters
        ----------
        files : query or list of dict, optional
            Files to check, by default None checks all files in the table
        restore : bool, optional
            Whether to initiate restore of archived files, by default True
        suppress_error : bool, optional
            Whether to suppress error if files are being restored, by default False

        Returns
        -------
        bool
            True if files are archived, False otherwise

        Raises
        ------
        OSError
            If storage is not in preferences or if files are being restored
        '''
        files_restoring = []
        import boto3
        if files is None:
            files = self
        for f in files:
            # check if files are archived
            # TODO: run this in parallel because it takes a while.
            if not f['storage'] in prefs['storage'].keys():
                raise(OSError(f"Store {f['storage']} is not in the preference file."))
            store = prefs['storage'][f['storage']]

            s3 = boto3.resource('s3',aws_access_key_id = store['access_key'],
                            aws_secret_access_key = store['secret_key'])

            obj = s3.Object(bucket_name = store['bucket'],
                            key = f['file_path'])

            if not obj.archive_status is None and 'ARCHIVE' in obj.archive_status:
                if obj.restore is None:
                    if restore:
                        resp = obj.restore_object(RestoreRequest = {})
                    files_restoring.append(f['file_path'])
                elif 'true' in obj.restore:
                    files_restoring.append(f['file_path'])
        if len(files_restoring):
            import warnings
            warnings.warn(f"Files are being restored [{files_restoring}]")
            if not suppress_error:
                raise(OSError(f"Files are being restored [{files_restoring}]"))
            return True # files are in archive
        return False # files are not in archive

check_if_files_archived(files=None, restore=True, suppress_error=False)

Check if files are archived in S3 Glacier storage.

Parameters:
  • files (query or list of dict, default: None ) –

    Files to check, by default None checks all files in the table

  • restore (bool, default: True ) –

    Whether to initiate restore of archived files, by default True

  • suppress_error (bool, default: False ) –

    Whether to suppress error if files are being restored, by default False

Returns:
  • bool

    True if files are archived, False otherwise

Raises:
  • OSError

    If storage is not in preferences or if files are being restored
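The per-object decision made inside the method's loop can be summarized as a small pure function. This is a sketch of the logic only (the function name and return labels are illustrative, not part of the package, which uses boto3's `Object.archive_status` and `Object.restore` attributes):

```python
def restore_action(archive_status, restore_header, restore=True):
    """Decide how one S3 object is handled when checking for Glacier archival.

    archive_status: e.g. 'DEEP_ARCHIVE_ACCESS', or None when not archived.
    restore_header: the object's x-amz-restore value, e.g.
                    'ongoing-request="true"', or None if no restore was requested.
    Returns 'skip' (readable now), 'request-restore', or 'wait' (restore in progress).
    """
    if archive_status is None or 'ARCHIVE' not in archive_status:
        return 'skip'                      # object is directly readable
    if restore_header is None:
        return 'request-restore' if restore else 'wait'
    if 'true' in restore_header:
        return 'wait'                      # a restore is already in progress
    return 'skip'                          # restore finished; object readable
```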

Source code in labdata/schema/general.py
def check_if_files_archived(self, files = None, restore = True, suppress_error = False):
    '''Check if files are archived in S3 Glacier storage.

    Parameters
    ----------
    files : query or list of dict, optional
        Files to check, by default None checks all files in the table
    restore : bool, optional
        Whether to initiate restore of archived files, by default True
    suppress_error : bool, optional
        Whether to suppress error if files are being restored, by default False

    Returns
    -------
    bool
        True if files are archived, False otherwise

    Raises
    ------
    OSError
        If storage is not in preferences or if files are being restored
    '''
    files_restoring = []
    import boto3
    if files is None:
        files = self
    for f in files:
        # check if files are archived
        # TODO: run this in parallel because it takes a while.
        if not f['storage'] in prefs['storage'].keys():
            raise(OSError(f"Store {f['storage']} is not in the preference file."))
        store = prefs['storage'][f['storage']]

        s3 = boto3.resource('s3',aws_access_key_id = store['access_key'],
                        aws_secret_access_key = store['secret_key'])

        obj = s3.Object(bucket_name = store['bucket'],
                        key = f['file_path'])

        if not obj.archive_status is None and 'ARCHIVE' in obj.archive_status:
            if obj.restore is None:
                if restore:
                    resp = obj.restore_object(RestoreRequest = {})
                files_restoring.append(f['file_path'])
            elif 'true' in obj.restore:
                files_restoring.append(f['file_path'])
    if len(files_restoring):
        import warnings
        warnings.warn(f"Files are being restored [{files_restoring}]")
        if not suppress_error:
            raise(OSError(f"Files are being restored [{files_restoring}]"))
        return True # files are in archive
    return False # files are not in archive

check_if_files_local(local_paths=None)

Checks if files are in a local path, searching across all local paths.

Parameters:
  • local_paths (list of str or Path, default: None ) –

    List of local paths to check for files, by default None uses paths in preferences

Returns:
  • tuple

    Tuple of local file paths and missing files

Raises:
  • ValueError

    If no files in the object

Source code in labdata/schema/general.py
def check_if_files_local(self, local_paths = None):
    '''
    Checks if files are in a local path, searching across all local paths.

    Parameters
    ----------
    local_paths : list of str or Path, optional
        List of local paths to check for files, by default None uses paths in preferences

    Returns
    -------
    tuple
        Tuple of local file paths and missing files

    Raises
    ------
    ValueError
        If no files in the object
    '''
    if local_paths is None:
        local_paths = prefs['local_paths']
    if not len(self):
        raise(ValueError('No files to get.'))
    # this does not work with multiple storages
    files = [f['file_path'] for f in self]
    localfiles = [find_local_filepath(a, local_paths = local_paths) for a in files]
    # check if they exist and download only missing files.
    missingfiles = []
    for f in files:
        if not np.any([str(l).endswith(str(Path(f))) for l in localfiles]):
            missingfiles.append(f)
    return [l for l in localfiles if not l is None], missingfiles
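The search can be pictured as resolving each database `file_path` against a list of local roots and splitting the result into found and missing files. A minimal, self-contained sketch of that logic (illustrative names; the package uses `find_local_filepath` with the paths from preferences):

```python
from pathlib import Path

def find_local(file_paths, local_paths):
    """Resolve relative file paths against a list of local root directories.

    Returns (found, missing): absolute paths that exist under some root,
    and the relative paths not found under any root.
    """
    found, missing = [], []
    for rel in file_paths:
        for root in local_paths:
            candidate = Path(root) / rel
            if candidate.exists():
                found.append(candidate)
                break
        else:  # no root contained the file
            missing.append(rel)
    return found, missing
```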

delete(transaction=True, safemode=None, force_parts=False)

Delete files from both the database and S3 storage.

Parameters:
  • transaction (bool, default: True ) –

    Whether to perform deletion as a transaction, by default True

  • safemode (bool, default: None ) –

    Whether to run in safe mode, by default None

  • force_parts (bool, default: False ) –

    Whether to force deletion of parts, by default False

Raises:
  • ValueError

    If files are deleted from database but not from S3

Source code in labdata/schema/general.py
    def delete(
            self,
            transaction = True,
            safemode  = None,
            force_parts = False):
        '''Delete files from both the database and S3 storage.

        Parameters
        ----------
        transaction : bool, optional
            Whether to perform deletion as a transaction, by default True
        safemode : bool, optional
            Whether to run in safe mode, by default None
        force_parts : bool, optional
            Whether to force deletion of parts, by default False

        Raises
        ------
        ValueError
            If files are deleted from database but not from S3
        '''

        from ..s3 import s3_delete_file
        from tqdm import tqdm
        filesdict = [f for f in self]
        super().delete(transaction = transaction,
                       safemode = safemode,
                       force_parts = force_parts)
        if len(self) == 0:
            files_not_deleted = []
            files_kept = []
            for s in tqdm(filesdict, desc = 'Deleting objects from s3 storage:'):
                fname = s["file_path"]
                storage = prefs['storage'][s['storage']]
                if storage['protocol'] == 's3':
                    try:
                        s3_delete_file(fname,
                                   storage = prefs['storage'][s['storage']],
                                   remove_versions = True)
                    except Exception as err:
                        print(f'Could not delete {fname}.')
                        files_not_deleted.append(fname)
                else:
                    print(f'Skipping {fname} because it is not in S3.')
                    files_kept.append(fname)
            if len(files_not_deleted):
                print('\n'.join(files_not_deleted))
                raise(ValueError('''

[Integrity error] Files were deleted from the database but not from AWS.

            Save this message and show it to your database ADMIN.

{0}

'''.format('\n'.join(files_not_deleted))))
            if len(files_kept):
                print('Files were not deleted from local storage.')

get(local_paths=None, check_if_archived=True, restore=True, download=True)

Download files from S3 to local storage.

Parameters:
  • local_paths (list of str or Path, default: None ) –

    List of local paths to download files to, by default None uses paths in preferences

  • check_if_archived (bool, default: True ) –

    Whether to check if files are in Glacier storage, by default True

  • restore (bool, default: True ) –

    Whether to restore archived files, by default True

  • download (bool, default: True ) –

    Whether to actually download the files, by default True

Returns:
  • list

    List of local file paths that were downloaded

Raises:
  • ValueError

    If no files are found to download

Source code in labdata/schema/general.py
def get(self,local_paths = None, check_if_archived = True, restore=True, download = True,):
    '''Download files from S3 to local storage.

    Parameters
    ----------
    local_paths : list of str or Path, optional
        List of local paths to download files to, by default None uses paths in preferences
    check_if_archived : bool, optional
        Whether to check if files are in Glacier storage, by default True
    restore : bool, optional
        Whether to restore archived files, by default True
    download : bool, optional
        Whether to actually download the files, by default True

    Returns
    -------
    list
        List of local file paths that were downloaded

    Raises
    ------
    ValueError
        If no files are found to download
    '''
    if local_paths is None:
        local_paths = prefs['local_paths']
    if not len(self):
        raise(ValueError('No files to get.'))

    localfiles, remotefiles = self.check_if_files_local(local_paths = local_paths)
    storage = [f['storage'] for f in self][0]
    remotefiles = self & [dict(file_path = f) for f in remotefiles]
    if len(remotefiles):
        if prefs['storage'][storage]['protocol'] == 's3':
            if check_if_archived:
                # TODO: add to the preference file to not restore by default.
                self.check_if_files_archived(files = remotefiles, restore = restore)
            if download:
                print(f'Downloading {len(remotefiles)} files from S3 [{storage}].')
                remotefiles = [r['file_path'] for r in remotefiles]
                dstfiles = [Path(local_paths[0])/f for f in remotefiles]  # place to store file.
                from ..s3 import copy_from_s3
                copy_from_s3(remotefiles,dstfiles,storage_name = storage)
                localfiles, _ = self.check_if_files_local(local_paths = local_paths)
        elif prefs['storage'][storage]['protocol'] == 'local':
            # TODO, copy files from local storage to the first local path.
            print('Downloading from local storage is not implemented, use local_paths.')
    return localfiles
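Note where downloads land: `get()` reproduces the S3 key layout under the first configured local path (the line `dstfiles = [Path(local_paths[0])/f for f in remotefiles]`). A sketch of that mapping, with an illustrative helper name:

```python
from pathlib import Path

def download_destinations(remote_keys, local_paths):
    """Destination paths for downloaded files: each S3 key is placed,
    with its full relative layout, under the first local path."""
    return [Path(local_paths[0]) / key for key in remote_keys]
```

So with `local_paths = ['/data/local', '/mnt/other']`, a key like `subject/session/dataset/file.bin` is written to `/data/local/subject/session/dataset/file.bin`; the remaining paths are only searched, never written to.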

AnalysisFile

Bases: File

Source code in labdata/schema/general.py
@globalschema # users need permission to delete from this table
class AnalysisFile(File):
    definition = '''
    file_path                 : varchar(300)  # Path to the file
    storage = "{0}"           : varchar(12)   # storage name 
    ---
    file_datetime             : datetime      # date created
    file_size                 : double        # using double because int64 does not exist
    file_md5 = NULL           : varchar(32)   # md5 checksum
    '''.format(DEFAULT_ANALYSIS_STORAGE)
    storage = DEFAULT_ANALYSIS_STORAGE
    # All users with permission to run analysis should also have permission to add and remove files from the analysis bucket in AWS
    def upload_files(self,src,dataset, force = True):
        '''
        Upload files to the AWS analysis bucket and register them in the table.

        src is a list of file paths; dataset is a dict with subject_name,
        session_name and dataset_name used to build the destination paths.
        '''

        dst = self.generate_filepaths(src, dataset)
        for d in dst:
            if len(AnalysisFile() & dict(file_path = d)) > 0:
                if not force:
                    raise ValueError(f'File is already in database, delete it to re-upload {d}.')
                else:
                    (AnalysisFile() & dict(file_path = d)).delete(safemode = False) # just the table
        assert self.storage in prefs['storage'].keys(), ValueError(
            f'Specify an {self.storage} bucket in preferences["storage"].')
        # compute checksum and sizes
        md5 = compute_md5s(src)
        dates = [datetime.utcfromtimestamp(Path(f).stat().st_mtime) for f in src]
        sizes = [Path(f).stat().st_size for f in src]
        # upload to s3
        from ..s3 import copy_to_s3
        copy_to_s3(src, dst, md5_checksum=None, storage_name = self.storage)
        # insert in AnalysisFile if all went well
        self.insert([dict(file_path = f,
                          storage = self.storage,
                          file_datetime = d,
                          file_md5 = m,
                          file_size = s) for f,d,s,m in zip(dst,dates,sizes,md5)])
        return [dict(file_path = f,storage = self.storage) for f in dst]

    def generate_filepaths(self,src, dataset):
        assert 'subject_name' in dataset.keys(), ValueError('dataset must have subject_name')
        assert 'session_name' in dataset.keys(), ValueError('dataset must have session_name')
        assert 'dataset_name' in dataset.keys(), ValueError('dataset must have dataset_name')
        destpath = '{subject_name}/{session_name}/{dataset_name}/'.format(**dataset)
        if not schema_project is None:
            destpath = f'{schema_project}/{destpath}' # add the project name so these files can have the same name across projects or be shared easily.
        return [destpath+Path(k).name for k in src]

upload_files(src, dataset, force=True)

Upload files to the AWS analysis bucket and register them in the table.

src is a list of file paths; dataset is a dict with subject_name, session_name and dataset_name used to build the destination paths.

Source code in labdata/schema/general.py
def upload_files(self,src,dataset, force = True):
    '''
    Upload files to the AWS analysis bucket and register them in the table.

    src is a list of file paths; dataset is a dict with subject_name,
    session_name and dataset_name used to build the destination paths.
    '''

    dst = self.generate_filepaths(src, dataset)
    for d in dst:
        if len(AnalysisFile() & dict(file_path = d)) > 0:
            if not force:
                raise ValueError(f'File is already in database, delete it to re-upload {d}.')
            else:
                (AnalysisFile() & dict(file_path = d)).delete(safemode = False) # just the table
    assert self.storage in prefs['storage'].keys(), ValueError(
        f'Specify an {self.storage} bucket in preferences["storage"].')
    # compute checksum and sizes
    md5 = compute_md5s(src)
    dates = [datetime.utcfromtimestamp(Path(f).stat().st_mtime) for f in src]
    sizes = [Path(f).stat().st_size for f in src]
    # upload to s3
    from ..s3 import copy_to_s3
    copy_to_s3(src, dst, md5_checksum=None, storage_name = self.storage)
    # insert in AnalysisFile if all went well
    self.insert([dict(file_path = f,
                      storage = self.storage,
                      file_datetime = d,
                      file_md5 = m,
                      file_size = s) for f,d,s,m in zip(dst,dates,sizes,md5)])
    return [dict(file_path = f,storage = self.storage) for f in dst]
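The destination-key scheme built by `generate_filepaths` is `{project}/{subject_name}/{session_name}/{dataset_name}/{filename}`, with the project prefix added only when a schema project is configured. A self-contained re-implementation for clarity (the example subject and session names are made up):

```python
from pathlib import Path

def analysis_filepaths(src, dataset, project=None):
    """Build destination keys for analysis files, mirroring the logic of
    AnalysisFile.generate_filepaths (illustrative, not the package function)."""
    for k in ('subject_name', 'session_name', 'dataset_name'):
        assert k in dataset, f'dataset must have {k}'
    dest = '{subject_name}/{session_name}/{dataset_name}/'.format(**dataset)
    if project is not None:
        dest = f'{project}/{dest}'  # project prefix keeps names unique across projects
    return [dest + Path(f).name for f in src]
```

The project prefix lets files keep the same relative name across projects, or be shared between them, without key collisions in the bucket.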

Procedure

Bases: Manual

Table for tracking experimental procedures performed on subjects.

Each procedure entry includes:

  • Subject
  • ProcedureType
  • Date and time
  • Lab member who performed it
  • Optional metadata: weight and notes

Source code in labdata/schema/procedures.py
@dataschema
class Procedure(dj.Manual):
    '''Table for tracking experimental procedures performed on subjects.

    Each procedure entry includes:
    - Subject
    - ProcedureType
    - Date and time
    - Lab member who performed it
    - Optional metadata: weight and notes
    '''
    definition = """
    -> Subject
    -> ProcedureType
    procedure_datetime            : datetime
    ---
    -> LabMember
    procedure_metadata = NULL     : longblob   
    -> [nullable] Weighing
    -> [nullable] Note
    """

ProcedureType

Bases: Lookup

Table defining types of experimental procedures.

This lookup table enumerates the different types of experimental procedures, including:

  • Surgical procedures (surgery, implants, craniotomy)
  • Behavioral procedures (handling, training)
  • Other manipulations (injections)

The procedure types are used by the Procedure table to categorize and track all procedures performed.

Source code in labdata/schema/procedures.py
@dataschema
class ProcedureType(dj.Lookup):
    '''Table defining types of experimental procedures.

    This lookup table enumerates the different types of experimental procedures, including:
    - Surgical procedures (surgery, implants, craniotomy)
    - Behavioral procedures (handling, training) 
    - Other manipulations (injections)

    The procedure types are used by the Procedure table to categorize and track all
    procedures performed.
    '''
    definition = """
    procedure_type : varchar(52)       #  Defines procedures that are not an experimental session
    """
    contents = zip(['surgery',
                    'chronic implant',
                    'chronic explant', 
                    'injection',
                    'window implant',
                    'window replacement',
                    'handling',
                    'training',
                    'craniotomy'])

Subject

Bases: Manual

Experimental subject.

Source code in labdata/schema/general.py
@dataschema
class Subject(dj.Manual):
    ''' Experimental subject.'''
    definition = """
    subject_name               : varchar(20)          # unique mouse id
    ---
    subject_dob                : date                 # mouse date of birth
    subject_sex                : enum('M', 'F', 'U')  # sex of mouse - Male, Female, or Unknown
    -> Strain
    -> LabMember
    """

Session

Bases: Manual

Source code in labdata/schema/general.py
@dataschema
class Session(dj.Manual):
    definition = """
    -> Subject
    session_name             : varchar(54)     # session identifier
    ---
    session_datetime         : datetime        # experiment date
    -> [nullable] LabMember.proj(experimenter = 'user_name') 
    """

Dataset

Bases: Manual

Source code in labdata/schema/general.py
@dataschema
class Dataset(dj.Manual):
    definition = """
    -> Subject
    -> Session
    dataset_name             : varchar(128)    
    ---
    -> [nullable] DatasetType
    -> [nullable] Setup
    -> [nullable] Note
    """
    class DataFiles(dj.Part):  # the files that were acquired on that dataset.
        definition = '''
        -> master
        -> File
        '''

DatasetEvents

Bases: Imported

Source code in labdata/schema/general.py
@dataschema
class DatasetEvents(dj.Imported):
    definition = '''
    -> Dataset
    stream_name                       : varchar(54)   # which clock is used e.g. btss, nidq, bpod, imecX
    ---
    stream_time = NULL                 : longblob     # for e.g. the analog channels
    '''
    class Digital(dj.Part):
        definition = '''
        -> master
        event_name                    : varchar(54)
        ---
        event_timestamps = NULL       : longblob  # timestamps of the events
        event_values = NULL           : longblob  # event value or count
        '''
        projkeys = ['subject_name','session_name','dataset_name','stream_name','event_name']

        def fetch_synced(self, force = False,method = 'cubic-spline'):
            '''
            Return events already synchronized between data streams, following the StreamSync table.
            method: cubic-spline or piecewise-linear
            '''
            keys = [dict(subject_name = s["subject_name"],
                         session_name = s["session_name"],
                         dataset_name = s["dataset_name"],
                         stream_name = s["stream_name"]) for s in self]
            evnts = []
            streams = (StreamSync() & keys)
            if not len(streams):
                from warnings import warn
                warn(f'There are no StreamSync for events {self.proj()}. This will return only the clock stream.')
            else:    
                for s in streams:
                    evs = (self & dict(stream_name = s["stream_name"])).fetch(as_dict = True)
                    func = (StreamSync() & s).apply(None, force = force,method = method)
                    for evnt in evs:
                        if not evnt['event_timestamps'] is None:
                            evnt['event_timestamps'] = func(evnt['event_timestamps'])
                            evnts.append(evnt)
            # add the events from the clock stream
            if len(streams):
                evs = (self & dict(stream_name = streams.clock_stream())).fetch(as_dict = True)
            else:
                evs = self
            for evnt in evs:
                evnts.append(evnt)
            return evnts

        def plot_synced(self, stream_colors = 'krbgyb', overlay_original = False, lw = 1,force = True):
            ''' Plot DatasetEvents.Digital.'''
            evnts = self.fetch_synced(force = force)
            ustreams = [n for n in np.unique([e['stream_name'] for e in evnts])]

            import pylab as plt
            caption = []
            ticks = []
            lns = []
            for i,e in enumerate(evnts):
                ln = plt.vlines(e['event_timestamps'],i,i+0.7,
                                color = stream_colors[np.mod(ustreams.index(e['stream_name']),len(stream_colors))],
                                lw = lw)
                if overlay_original:
                    ee = (DatasetEvents.Digital() & {k:e[k] for k in self.projkeys}).fetch('event_timestamps')[0]
                    plt.vlines(ee,i+0.6,i+0.9,color='gray',lw = lw)
                lns.append(ln)
                ticks.append(i+0.35)
                caption.append(f'{e["stream_name"]}_{e["event_name"]}')
            plt.yticks(ticks,caption);
            return lns

    class AnalogChannel(dj.Part):
        definition = '''
        -> master
        channel_name                 : varchar(54)
        ---
        channel_values = NULL        : longblob  # analog values for channel
        '''

Digital

Bases: Part

Source code in labdata/schema/general.py
class Digital(dj.Part):
    definition = '''
    -> master
    event_name                    : varchar(54)
    ---
    event_timestamps = NULL       : longblob  # timestamps of the events
    event_values = NULL           : longblob  # event value or count
    '''
    projkeys = ['subject_name','session_name','dataset_name','stream_name','event_name']

    def fetch_synced(self, force = False,method = 'cubic-spline'):
        ''' 
        Returns events already synchronized between data streams, following the StreamSync table.
        method: 'cubic-spline' or 'piecewise-linear'
        '''
        keys = [dict(subject_name = s["subject_name"],
                     session_name = s["session_name"],
                     dataset_name = s["dataset_name"],
                     stream_name = s["stream_name"]) for s in self]
        evnts = []
        streams = (StreamSync() & keys)
        if not len(streams):
            from warnings import warn
            warn(f'There are no StreamSync entries for events {self.proj()}. This will return only the clock stream.')
        else:    
            for s in streams:
                evs = (self & dict(stream_name = s["stream_name"])).fetch(as_dict = True)
                func = (StreamSync() & s).apply(None, force = force,method = method)
                for evnt in evs:
                    if not evnt['event_timestamps'] is None:
                        evnt['event_timestamps'] = func(evnt['event_timestamps'])
                        evnts.append(evnt)
        # add the events from the clock stream
        if len(streams):
            evs = (self & dict(stream_name = streams.clock_stream())).fetch(as_dict = True)
        else:
            evs = self
        for evnt in evs:
            evnts.append(evnt)
        return evnts

    def plot_synced(self, stream_colors = 'krbgyb', overlay_original = False, lw = 1,force = True):
        ''' Plot DatasetEvents.Digital.'''
        evnts = self.fetch_synced(force = force)
        ustreams = [n for n in np.unique([e['stream_name'] for e in evnts])]

        import pylab as plt
        caption = []
        ticks = []
        lns = []
        for i,e in enumerate(evnts):
            ln = plt.vlines(e['event_timestamps'],i,i+0.7,
                            color = stream_colors[np.mod(ustreams.index(e['stream_name']),len(stream_colors))],
                            lw = lw)
            if overlay_original:
                ee = (DatasetEvents.Digital() & {k:e[k] for k in self.projkeys}).fetch('event_timestamps')[0]
                plt.vlines(ee,i+0.6,i+0.9,color='gray',lw = lw)
            lns.append(ln)
            ticks.append(i+0.35)
            caption.append(f'{e["stream_name"]}_{e["event_name"]}')
        plt.yticks(ticks,caption);
        return lns

fetch_synced(force=False, method='cubic-spline')

Returns events already synchronized between data streams, following the StreamSync table. method: 'cubic-spline' or 'piecewise-linear'.

Source code in labdata/schema/general.py
def fetch_synced(self, force = False,method = 'cubic-spline'):
    ''' 
    Returns events already synchronized between data streams, following the StreamSync table.
    method: 'cubic-spline' or 'piecewise-linear'
    '''
    keys = [dict(subject_name = s["subject_name"],
                 session_name = s["session_name"],
                 dataset_name = s["dataset_name"],
                 stream_name = s["stream_name"]) for s in self]
    evnts = []
    streams = (StreamSync() & keys)
    if not len(streams):
        from warnings import warn
        warn(f'There are no StreamSync entries for events {self.proj()}. This will return only the clock stream.')
    else:    
        for s in streams:
            evs = (self & dict(stream_name = s["stream_name"])).fetch(as_dict = True)
            func = (StreamSync() & s).apply(None, force = force,method = method)
            for evnt in evs:
                if not evnt['event_timestamps'] is None:
                    evnt['event_timestamps'] = func(evnt['event_timestamps'])
                    evnts.append(evnt)
    # add the events from the clock stream
    if len(streams):
        evs = (self & dict(stream_name = streams.clock_stream())).fetch(as_dict = True)
    else:
        evs = self
    for evnt in evs:
        evnts.append(evnt)
    return evnts
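The mapping performed per stream can be sketched independently of the database: timestamps recorded on one acquisition system are interpolated onto the clock stream's timebase through the sync pulses both systems recorded. A minimal illustration with hypothetical pulse times, using the piecewise-linear method:

```python
import numpy as np

# Hypothetical sync pulses seen by two acquisition systems. The "sync"
# stream runs on its own clock, offset and scaled relative to the
# reference ("clock") stream.
clock_onsets = np.array([0.0, 1.0, 2.0, 3.0, 4.0])  # reference timebase (s)
sync_onsets = 10.0 + 2.0 * clock_onsets             # same pulses, local timebase

# Piecewise-linear mapping from the local timebase to the reference one,
# analogous to method='piecewise-linear' in StreamSync.apply.
to_clock = lambda t: np.interp(t, sync_onsets, clock_onsets)

# Event timestamps recorded on the local clock land on the reference timebase.
synced = to_clock(np.array([11.0, 15.0]))
print(synced)  # → [0.5 2.5]
```

`fetch_synced` applies this kind of mapping (obtained from `StreamSync.apply`) to each stream's `event_timestamps`, and appends the clock stream's events unchanged.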

plot_synced(stream_colors='krbgyb', overlay_original=False, lw=1, force=True)

Plot DatasetEvents.Digital.

Source code in labdata/schema/general.py
def plot_synced(self, stream_colors = 'krbgyb', overlay_original = False, lw = 1,force = True):
    ''' Plot DatasetEvents.Digital.'''
    evnts = self.fetch_synced(force = force)
    ustreams = [n for n in np.unique([e['stream_name'] for e in evnts])]

    import pylab as plt
    caption = []
    ticks = []
    lns = []
    for i,e in enumerate(evnts):
        ln = plt.vlines(e['event_timestamps'],i,i+0.7,
                        color = stream_colors[np.mod(ustreams.index(e['stream_name']),len(stream_colors))],
                        lw = lw)
        if overlay_original:
            ee = (DatasetEvents.Digital() & {k:e[k] for k in self.projkeys}).fetch('event_timestamps')[0]
            plt.vlines(ee,i+0.6,i+0.9,color='gray',lw = lw)
        lns.append(ln)
        ticks.append(i+0.35)
        caption.append(f'{e["stream_name"]}_{e["event_name"]}')
    plt.yticks(ticks,caption);
    return lns

StreamSync

Bases: Manual

Source code in labdata/schema/general.py
@analysisschema
class StreamSync(dj.Manual):
    definition = '''
    -> Dataset
    -> DatasetEvents.Digital
    -> DatasetEvents.Digital.proj(clock_stream='stream_name',clock_stream_event='event_name',clock_dataset = 'dataset_name')
    '''
    def get_interp_data(self,force = False, warn = True, allowed_offset = 2):
        ''' Force will attempt to remove events from the longest stream so the streams are matched. '''

        assert len(self)==1, ValueError(f"This function only takes one element at a time not {len(self)}.")
        s = self.fetch1()
        clock = (DatasetEvents.Digital() & dict(subject_name = s['subject_name'],
                                                session_name = s['session_name'],
                                                dataset_name = s['clock_dataset'],
                                                stream_name = s['clock_stream'],
                                                event_name = s['clock_stream_event'])).fetch1()
        clock_onsets = np.array(clock['event_timestamps'])
        sync = (DatasetEvents.Digital() & dict(subject_name = s['subject_name'],
                                               session_name = s['session_name'],
                                               dataset_name = s['dataset_name'],
                                               stream_name = s['stream_name'],
                                               event_name = s['event_name'])).fetch1()
        sync_onsets = np.array(sync['event_timestamps'])
        if ((len(sync_onsets)>=(len(clock_onsets)//2)-allowed_offset) and 
            (len(sync_onsets)<=(len(clock_onsets)//2)+allowed_offset)): # in case clock has both onsets and offsets
            if clock['event_values'] is None:
                clock_onsets = clock_onsets[::2]
            else:
                clock_onsets = clock_onsets[np.array(clock['event_values']) == 1]
        if ((len(clock_onsets)>=(len(sync_onsets)//2)-allowed_offset) and 
            (len(clock_onsets)<=(len(sync_onsets)//2)+allowed_offset)): # in case sync has both onsets and offsets
            if sync['event_values'] is None:
                sync_onsets = sync_onsets[::2]
            else:
                sync_onsets = sync_onsets[np.array(sync['event_values'])==1]
        if warn: 
            if not len(sync_onsets) == len(clock_onsets):
                import warnings
                warnings.warn(f"There is a potential issue with the syncronization of sessions: {s['subject_name']} {s['session_name']}", UserWarning)
                print(f"    - stream {s['clock_stream']} channel {s['clock_stream_event']} {len(clock_onsets)}")
                print(f"    - stream {s['stream_name']} channel {s['event_name']} {len(sync_onsets)}")
        if not force: # by default this is set to false.
            assert len(clock_onsets) == len(sync_onsets), ValueError(f'\n\n Length of the clock and sync not the same? \n\n {self}')
        N = np.min([len(sync_onsets),len(clock_onsets)])
        return sync_onsets[:N],clock_onsets[:N]

    def apply(self, values, sync_onsets = None, warn = True, clock_onsets = None, force = False, method = 'cubic-spline'):
        '''
        Returns signals synchronized according to a sync pulse shared from a clock.
        "clock" is the main stream; "sync" is on the same acquisition system as "values"
        '''
        if sync_onsets is None or clock_onsets is None:
            sync_onsets, clock_onsets = self.get_interp_data(force = force, warn = warn)
        if method == 'cubic-spline':
            # cubic spline interpolation handles the extrapolated points better
            from scipy.interpolate import CubicSpline
            func = CubicSpline(sync_onsets,clock_onsets)
        elif method == 'piecewise-linear':
            # linear interpolation to get the time of events synchronized across streams
            func = lambda x: np.interp(x,sync_onsets,clock_onsets)
        else:
            raise(ValueError(f'Unknown interpolation method {method}'))
        if values is None: # return a function if no values passed
            return func
        return func(values)

    def clock_stream(self):
        '''
        Returns the name of the clock stream(s)
        '''
        clkstreams = np.unique(self.fetch('clock_stream'))
        if len(clkstreams) == 1:
            return clkstreams[0]
        return clkstreams

apply(values, sync_onsets=None, warn=True, clock_onsets=None, force=False, method='cubic-spline')

Returns signals synchronized according to a sync pulse shared from a clock. "clock" is the main stream; "sync" is on the same acquisition system as "values".

Source code in labdata/schema/general.py
def apply(self, values, sync_onsets = None, warn = True, clock_onsets = None, force = False, method = 'cubic-spline'):
    '''
    Returns signals synchronized according to a sync pulse shared from a clock.
    "clock" is the main stream; "sync" is on the same acquisition system as "values"
    '''
    if sync_onsets is None or clock_onsets is None:
        sync_onsets, clock_onsets = self.get_interp_data(force = force, warn = warn)
    if method == 'cubic-spline':
        # cubic spline interpolation handles the extrapolated points better
        from scipy.interpolate import CubicSpline
        func = CubicSpline(sync_onsets,clock_onsets)
    elif method == 'piecewise-linear':
        # linear interpolation to get the time of events syncronized across streams
        func = lambda x: np.interp(x,sync_onsets,clock_onsets)
    else:
        raise(ValueError(f'Unknown interpolation method {method}'))
    if values is None: # return a function if no values passed
        return func
    return func(values)
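The difference between the two `method` options shows up mainly beyond the range of the sync pulses: `np.interp` clamps to the endpoints, while `CubicSpline` extrapolates (the reason for the comment above). A sketch with hypothetical pulse times:

```python
import numpy as np
from scipy.interpolate import CubicSpline

# Hypothetical sync pulses: the local stream runs 2x fast, offset by 10 s.
clock_onsets = np.array([0.0, 1.0, 2.0, 3.0])
sync_onsets = 10.0 + 2.0 * clock_onsets

spline = CubicSpline(sync_onsets, clock_onsets)             # method='cubic-spline'
linear = lambda t: np.interp(t, sync_onsets, clock_onsets)  # method='piecewise-linear'

# An event 1 s after the last sync pulse (local time 17.0):
print(round(float(spline(17.0)), 6))  # extrapolates past the last pulse → 3.5
print(linear(17.0))                   # np.interp clamps to the last pulse → 3.0
```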

clock_stream()

Returns the name of the clock stream(s)

Source code in labdata/schema/general.py
def clock_stream(self):
    '''
    Returns the name of the clock stream(s)
    '''
    clkstreams = np.unique(self.fetch('clock_stream'))
    if len(clkstreams) == 1:
        return clkstreams[0]
    return clkstreams

get_interp_data(force=False, warn=True, allowed_offset=2)

Force will attempt to remove events from the longest stream so the streams are matched.

Source code in labdata/schema/general.py
def get_interp_data(self,force = False, warn = True, allowed_offset = 2):
    ''' Force will attempt to remove events from the longest stream so the streams are matched. '''

    assert len(self)==1, ValueError(f"This function only takes one element at a time not {len(self)}.")
    s = self.fetch1()
    clock = (DatasetEvents.Digital() & dict(subject_name = s['subject_name'],
                                            session_name = s['session_name'],
                                            dataset_name = s['clock_dataset'],
                                            stream_name = s['clock_stream'],
                                            event_name = s['clock_stream_event'])).fetch1()
    clock_onsets = np.array(clock['event_timestamps'])
    sync = (DatasetEvents.Digital() & dict(subject_name = s['subject_name'],
                                           session_name = s['session_name'],
                                           dataset_name = s['dataset_name'],
                                           stream_name = s['stream_name'],
                                           event_name = s['event_name'])).fetch1()
    sync_onsets = np.array(sync['event_timestamps'])
    if ((len(sync_onsets)>=(len(clock_onsets)//2)-allowed_offset) and 
        (len(sync_onsets)<=(len(clock_onsets)//2)+allowed_offset)): # in case clock has both onsets and offsets
        if clock['event_values'] is None:
            clock_onsets = clock_onsets[::2]
        else:
            clock_onsets = clock_onsets[np.array(clock['event_values']) == 1]
    if ((len(clock_onsets)>=(len(sync_onsets)//2)-allowed_offset) and 
        (len(clock_onsets)<=(len(sync_onsets)//2)+allowed_offset)): # in case sync has both onsets and offsets
        if sync['event_values'] is None:
            sync_onsets = sync_onsets[::2]
        else:
            sync_onsets = sync_onsets[np.array(sync['event_values'])==1]
    if warn: 
        if not len(sync_onsets) == len(clock_onsets):
            import warnings
            warnings.warn(f"There is a potential issue with the synchronization of sessions: {s['subject_name']} {s['session_name']}", UserWarning)
            print(f"    - stream {s['clock_stream']} channel {s['clock_stream_event']} {len(clock_onsets)}")
            print(f"    - stream {s['stream_name']} channel {s['event_name']} {len(sync_onsets)}")
    if not force: # by default this is set to false.
        assert len(clock_onsets) == len(sync_onsets), ValueError(f'\n\n Length of the clock and sync not the same? \n\n {self}')
    N = np.min([len(sync_onsets),len(clock_onsets)])
    return sync_onsets[:N],clock_onsets[:N]
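The onset/offset handling above can be illustrated standalone: when a stream recorded both edges of each sync pulse (about twice as many timestamps as the other stream), either every other timestamp is kept or only those flagged as rising edges in `event_values`. A sketch with hypothetical values:

```python
import numpy as np

# Hypothetical timestamps from a stream that recorded both edges of each pulse.
timestamps = np.array([0.0, 0.1, 1.0, 1.1, 2.0, 2.1])
event_values = np.array([1, 0, 1, 0, 1, 0])  # 1 = rising edge, 0 = falling edge

# Without edge labels (event_values is None): keep every other timestamp,
# assuming the recording starts on a rising edge.
onsets_by_stride = timestamps[::2]

# With edge labels: keep only the rising edges.
onsets_by_value = timestamps[event_values == 1]

print(onsets_by_stride)  # → [0. 1. 2.]
print(onsets_by_value)   # → [0. 1. 2.]
```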

Watering

Bases: Manual

Table for tracking water administration to subjects.

This table records water consumed, including:

- Subject receiving water
- Date and time
- Volume of water, in microliters

Source code in labdata/schema/procedures.py
@dataschema
class Watering(dj.Manual):
    '''Table for tracking water administration to subjects.

    This table records water consumed, including:
    - Subject receiving water
    - Date and time
    - Volume of water, in microliters

    '''
    definition = """
    -> Subject
    watering_datetime : datetime
    ---
    water_volume : float  # (uL)
    """

Weighing

Bases: Manual

Table for tracking subject weights.

This table stores weight measurements for experimental subjects. Each entry includes:

- Subject
- Date and time of weighing
- Weight in grams

Source code in labdata/schema/procedures.py
@dataschema
class Weighing(dj.Manual):
    '''Table for tracking subject weights.

    This table stores weight measurements for experimental subjects. Each entry includes:
    - Subject
    - Date and time of weighing
    - Weight in grams

    '''
    definition = """
    -> Subject
    weighing_datetime : datetime
    ---
    weight : float  # (g)
    """

DecisionTask

Bases: Imported

Table for behavioral decision task data.

This table serves as a general schema for decision-making behavioral tasks, abstracting common features that can be inherited by specific task tables defined in plugins. Each entry includes:

- Total trial counts (assisted, self-performed, initiated, with choice)
- Performance metrics (rewarded, punished trials)
- Optional reference to water intake during session

The table includes a Part table (TrialSet) that stores detailed sets of trials within a session, dependent on the modality or condition, including:

- Trial conditions and modalities
- Performance metrics per condition
- Timing data (initiation times, reaction times)
- Response and subject feedback values
- Stimulus parameters (intensity, block)

This data can be used to compute:

- Psychometric curves
- Learning curves
- Reaction time distributions
- Choice biases and strategies
- ...

The schema is designed to be flexible and is meant to be populated by specific task tables (defined as plugins) while maintaining a consistent interface for analysis and visualization.

Source code in labdata/schema/tasks.py
@analysisschema
class DecisionTask(dj.Imported): # imported because if comes from data but there is no 'make'
    '''Table for behavioral decision task data.

    This table serves as a general schema for decision-making behavioral tasks,
    abstracting common features that can be inherited by specific task tables
    defined in plugins. Each entry includes:
    - Total trial counts (assisted, self-performed, initiated, with choice)
    - Performance metrics (rewarded, punished trials)
    - Optional reference to water intake during session

    The table includes a Part table (TrialSet) that stores detailed sets of trials within a session, 
    dependent on the modality or condition, including:
    - Trial conditions and modalities
    - Performance metrics per condition
    - Timing data (initiation times, reaction times) 
    - Response and subject feedback values
    - Stimulus parameters (intensity, block)

    This data can be used to compute:
    - Psychometric curves
    - Learning curves
    - Reaction time distributions
    - Choice biases and strategies
    - ...

    The schema is designed to be flexible and is meant to be populated by specific task tables (defined as plugins) 
    while maintaining a consistent interface for analysis and visualization.
    '''
    definition = '''
    -> Dataset
    ---
    n_total_trials              : int            # number of trials in the session
    n_total_assisted = NULL     : int            # number of assisted trials in the session
    n_total_performed = NULL    : int            # number of self-performed trials
    n_total_initiated = NULL    : int            # number of initiated trials
    n_total_with_choice = NULL  : int            # number of self-initiated with choice
    n_total_rewarded = NULL     : int            # number of rewarded trials
    n_total_punished = NULL     : int            # number of punished trials
    -> [nullable] Watering                       # water intake during the session (ml) 
    '''

    class TrialSet(dj.Part):
        definition = '''
        -> master
        trialset_description     : varchar(54) # e.g. trial modality, unique condition
        ---
        n_trials                 : int         # total number of trials
        n_performed              : int         # number of self-performed trials
        n_with_choice            : int         # number of self-initiated trials with choice 
        n_correct                : int         # number of correct trials
        performance_easy = NULL  : float       # performance on easy trials
        performance = NULL       : float       # performance on all trials
        trial_num                : longblob    # trial number, because TrialSets can be interleaved
        initiation_times = NULL  : longblob    # time between trial start and stim onset
        assisted = NULL          : longblob    # whether the trial was assisted
        response_values = NULL   : longblob    # left=1;no response=0; right=-1        
        correct_values = NULL    : longblob    # correct = 1; no_response  = NaN; wrong = 0        
        intensity_values = NULL  : longblob    # value of the stim (left-right)
        reaction_times = NULL    : longblob    # between onset of the response period and reporting  
        block_values = NULL      : longblob    # block number for each trial
        '''
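As an example of the psychometric curves mentioned above, the fraction of leftward choices per stimulus intensity can be computed from a TrialSet's `response_values` (left=1, right=-1) and `intensity_values`. A sketch with hypothetical arrays:

```python
import numpy as np

# Hypothetical TrialSet arrays (no-response trials already excluded).
intensity_values = np.array([-1.0, -1.0, -0.5, -0.5, 0.5, 0.5, 1.0, 1.0])
response_values  = np.array([  -1,   -1,   -1,    1,    1,   -1,    1,    1])

# Fraction of leftward (response == 1) choices at each stimulus intensity.
intensities = np.unique(intensity_values)
p_left = np.array([np.mean(response_values[intensity_values == i] == 1)
                   for i in intensities])
print(p_left)  # → [0.  0.5 0.5 1. ]
```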

PoseEstimationLabelSet

Bases: Manual

Source code in labdata/schema/video.py
@analysisschema
class PoseEstimationLabelSet(dj.Manual):
    definition = '''
    pose_label_set_num    : int
    ---
    description = NULL    :  varchar(512)
    -> [nullable] LabMember.proj(labeler = "user_name")
    '''
    class Frame(dj.Part):
        definition = '''
        -> master
        -> DatasetVideo
        frame_num : int
        ---
        frame     : longblob
        '''
    class Label(dj.Part):
        definition = '''
        -> master
        -> DatasetVideo
        frame_num  : int
        label_name : varchar(54)
        ---
        x          : float
        y          : float
        z = NULL   : float 
        '''
    def export_labeling(self, model_num = None, bodyparts = None, disperse_labels = False, export_only_labeled = False):
        '''
        Exports labeling for PoseEstimation (for use with napari-deeplabcut)
        '''
        assert len(self) == 1, ValueError('PoseEstimationLabelSet, select only one set to export.')
        k = self.proj().fetch1()

        if export_only_labeled:
            frames = pd.DataFrame((PoseEstimationLabelSet()*PoseEstimationLabelSet.Frame() & (PoseEstimationLabelSet.Label() & k)).fetch())
        else:
            frames = pd.DataFrame((PoseEstimationLabelSet()*PoseEstimationLabelSet.Frame() & k).fetch())
        frame_labels = pd.DataFrame((PoseEstimationLabelSet()*PoseEstimationLabelSet.Label() & k).fetch())

        if model_num is None:
            folder = (Path(prefs['local_paths'][0])/'pose_estimation_models')/f'pose_label_set_num_{k["pose_label_set_num"]}'
        else:
            folder = (Path(prefs['local_paths'][0])/'pose_estimation_models')/f'model_{model_num}'
        data_path = (folder / "labeled-data") / f'label_set_{k["pose_label_set_num"]}'
        data_path.mkdir(parents=True, exist_ok=True)
        if bodyparts is None:
            bodyparts = np.unique(frame_labels.label_name.values)
        from natsort import natsorted
        bodyparts = natsorted(bodyparts) # this is an attempt to sort the labels
        labeler = frames['labeler'].iloc[0]
        from skimage.io import imsave
        from tqdm import tqdm
        todlc = []
        for i,f in tqdm(enumerate(frames.frame_num.values),desc = "Exporting labeling dataset:",total = len(frames)):
            im_name = 'im_{0:06d}_session{2}_frame{1:06d}'.format(i,f,frames.session_name.iloc[i])
            for bpart in bodyparts:
                t = (PoseEstimationLabelSet.Label & dict(
                    pose_label_set_num = k['pose_label_set_num'],
                    frame_num = f,
                    label_name = bpart)).fetch()
                x = np.nan
                y = np.nan
                if disperse_labels:
                    if i == 0:
                        x = i*20
                        y = 100
                if len(t):
                    x = t['x'][0]
                    y = t['y'][0]
                todlc.append(dict(scorer = labeler,
                                bodyparts = bpart,
                                level_0 = 'labeled-data',
                                level_1 = f'label_set_{k["pose_label_set_num"]}',
                                level_2 = f'{im_name}.png',
                                x = x,
                                y = y))
            fname = data_path/f'{im_name}.png'
            if not fname.exists():
                imsave(fname,frames.iloc[i].frame)
        df = pd.DataFrame(todlc)
        df = df.set_index(["scorer", "bodyparts","level_0","level_1","level_2"]).stack()
        df.index.set_names("coords", level=-1, inplace=True)
        df = df.unstack(["scorer", "bodyparts", "coords"])
        df.index.name = None
        df.to_hdf(data_path/f'CollectedData_{labeler}.h5',key='keypoints')
        return data_path,frames,frame_labels

    def update_labeling(self, labeling_file):
        '''
        (PoseEstimationLabelSet() & 'pose_label_set_num =3').update_labeling('filename.h5')

        Updates the labels in the PoseEstimationLabelSet from a file.
        Currently only DLC format is supported.

        Reach out if you need other formats.
         Joao Couto 2023
        '''
        dlcres = pd.read_hdf(labeling_file)
        scorer = np.unique(dlcres.columns.get_level_values(0))[0]
        bodyparts = np.unique(dlcres.columns.get_level_values(1))
        frame_nums = [int(f.split('frame')[-1].strip('.png')) 
                    for f in dlcres.reset_index()['level_2'].values]
        frame_names = dlcres.reset_index()['level_2'].values
        labels = []
        from tqdm import tqdm
        labels_to_insert = [] # insert the labels in parallel will be faster.
        labels_to_delete = [] # need to delete all labels for a frame before adding the new ones
        for iframe,frame_name in tqdm(enumerate(frame_names),desc = 'Updating labels',total = len(frame_names)):
            frame_num = int(frame_name.split('frame')[-1].strip('.png'))
            frame_key = dict(frame_num = frame_num)
            if 'session' in frame_name: # get the session name so there are no conflicting frame numbers
                 frame_key['session_name'] = frame_name.split('session')[-1].split('_frame')[0]
            frame_key = (PoseEstimationLabelSet.Frame() & self.proj().fetch1() & frame_key).proj().fetch1()
            labels_to_delete.extend((PoseEstimationLabelSet.Label() & frame_key).proj().fetch(as_dict = True))
            for dlcname in bodyparts:
                if np.isnan(dlcres[scorer][dlcname].iloc[iframe]['x']):
                    continue # if it is NaN, don't add
                if dlcres[scorer][dlcname].iloc[iframe]['x'] == 0 and dlcres[scorer][dlcname].iloc[iframe]['y'] == 0:
                    continue # if the label is at 0,0  don't add
                label = dict(dict(frame_key,label_name = dlcname),
                            label_name = dlcname,
                            x = dlcres[scorer][dlcname].iloc[iframe]['x'],
                            y = dlcres[scorer][dlcname].iloc[iframe]['y'])
                labels_to_insert.append(label)
        (PoseEstimationLabelSet.Label() & labels_to_delete).delete(force = True) # ask the user to confirm
        PoseEstimationLabelSet.Label.insert(labels_to_insert)

export_labeling(model_num=None, bodyparts=None, disperse_labels=False, export_only_labeled=False)

Exports labeling for PoseEstimation (for use with napari-deeplabcut)

Source code in labdata/schema/video.py
def export_labeling(self, model_num = None, bodyparts = None, disperse_labels = False, export_only_labeled = False):
    '''
    Exports labeling for PoseEstimation (for use with napari-deeplabcut)
    '''
    assert len(self) == 1, ValueError('PoseEstimationLabelSet, select only one set to export.')
    k = self.proj().fetch1()

    if export_only_labeled:
        frames = pd.DataFrame((PoseEstimationLabelSet()*PoseEstimationLabelSet.Frame() & (PoseEstimationLabelSet.Label() & k)).fetch())
    else:
        frames = pd.DataFrame((PoseEstimationLabelSet()*PoseEstimationLabelSet.Frame() & k).fetch())
    frame_labels = pd.DataFrame((PoseEstimationLabelSet()*PoseEstimationLabelSet.Label() & k).fetch())

    if model_num is None:
        folder = (Path(prefs['local_paths'][0])/'pose_estimation_models')/f'pose_label_set_num_{k["pose_label_set_num"]}'
    else:
        folder = (Path(prefs['local_paths'][0])/'pose_estimation_models')/f'model_{model_num}'
    data_path = (folder / "labeled-data") / f'label_set_{k["pose_label_set_num"]}'
    data_path.mkdir(parents=True, exist_ok=True)
    if bodyparts is None:
        bodyparts = np.unique(frame_labels.label_name.values)
    from natsort import natsorted
    bodyparts = natsorted(bodyparts) # this is an attempt to sort the labels
    labeler = frames['labeler'].iloc[0]
    from skimage.io import imsave
    from tqdm import tqdm
    todlc = []
    for i,f in tqdm(enumerate(frames.frame_num.values),desc = "Exporting labeling dataset:",total = len(frames)):
        im_name = 'im_{0:06d}_session{2}_frame{1:06d}'.format(i,f,frames.session_name.iloc[i])
        for bpart in bodyparts:
            t = (PoseEstimationLabelSet.Label & dict(
                pose_label_set_num = k['pose_label_set_num'],
                frame_num = f,
                label_name = bpart)).fetch()
            x = np.nan
            y = np.nan
            if disperse_labels:
                if i == 0:
                    x = i*20
                    y = 100
            if len(t):
                x = t['x'][0]
                y = t['y'][0]
            todlc.append(dict(scorer = labeler,
                            bodyparts = bpart,
                            level_0 = 'labeled-data',
                            level_1 = f'label_set_{k["pose_label_set_num"]}',
                            level_2 = f'{im_name}.png',
                            x = x,
                            y = y))
        fname = data_path/f'{im_name}.png'
        if not fname.exists():
            imsave(fname,frames.iloc[i].frame)
    df = pd.DataFrame(todlc)
    df = df.set_index(["scorer", "bodyparts","level_0","level_1","level_2"]).stack()
    df.index.set_names("coords", level=-1, inplace=True)
    df = df.unstack(["scorer", "bodyparts", "coords"])
    df.index.name = None
    df.to_hdf(data_path/f'CollectedData_{labeler}.h5',key='keypoints')
    return data_path,frames,frame_labels
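The pivot at the end of export_labeling converts the long-format label records into the (scorer, bodyparts, coords) column MultiIndex that napari-deeplabcut reads. A minimal sketch of that reshape with toy data (the record values are illustrative):

```python
import pandas as pd

# Toy long-format records mirroring the `todlc` dicts built in export_labeling
records = [
    dict(scorer='jc', bodyparts='nose', level_0='labeled-data',
         level_1='label_set_1', level_2='im_000000.png', x=10.0, y=20.0),
    dict(scorer='jc', bodyparts='paw', level_0='labeled-data',
         level_1='label_set_1', level_2='im_000000.png', x=3.0, y=5.0),
]
df = pd.DataFrame(records)
# Pivot to the (scorer, bodyparts, coords) column MultiIndex used by DLC files
df = df.set_index(['scorer', 'bodyparts', 'level_0', 'level_1', 'level_2']).stack()
df.index = df.index.set_names('coords', level=-1)
df = df.unstack(['scorer', 'bodyparts', 'coords'])
# df now has one row per image and a 3-level column index
```

The stack/unstack pair is what turns the per-coordinate rows into paired x/y columns per bodypart.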

update_labeling(labeling_file)

(PoseEstimationLabelSet() & 'pose_label_set_num =3').update_labeling('filename.h5')

Updates the labels in the PoseEstimationLabelSet from a file. Currently only the DLC format is supported.

Reach out if you need other formats. Joao Couto 2023

Source code in labdata/schema/video.py
def update_labeling(self, labeling_file):
    '''
    (PoseEstimationLabelSet() & 'pose_label_set_num =3').update_labeling('filename.h5')

    Updates the labels in the PoseEstimationLabelSet from a file.
    Currently only DLC format is supported.

    Reach out if you need other formats.
     Joao Couto 2023
    '''
    dlcres = pd.read_hdf(labeling_file)
    scorer = np.unique(dlcres.columns.get_level_values(0))[0]
    bodyparts = np.unique(dlcres.columns.get_level_values(1))
    frame_nums = [int(f.split('frame')[-1].strip('.png')) 
                for f in dlcres.reset_index()['level_2'].values]
    frame_names = dlcres.reset_index()['level_2'].values
    labels = []
    from tqdm import tqdm
    labels_to_insert = [] # insert the labels in parallel will be faster.
    labels_to_delete = [] # need to delete all labels for a frame before adding the new ones
    for iframe,frame_name in tqdm(enumerate(frame_names),desc = 'Updating labels',total = len(frame_names)):
        frame_num = int(frame_name.split('frame')[-1].strip('.png'))
        frame_key = dict(frame_num = frame_num)
        if 'session' in frame_name: # get the session name so there are no conflicting frame numbers
             frame_key['session_name'] = frame_name.split('session')[-1].split('_frame')[0]
        frame_key = (PoseEstimationLabelSet.Frame() & self.proj().fetch1() & frame_key).proj().fetch1()
        labels_to_delete.extend((PoseEstimationLabelSet.Label() & frame_key).proj().fetch(as_dict = True))
        for dlcname in bodyparts:
            if np.isnan(dlcres[scorer][dlcname].iloc[iframe]['x']):
                continue # if it is NaN, don't add
            if dlcres[scorer][dlcname].iloc[iframe]['x'] == 0 and dlcres[scorer][dlcname].iloc[iframe]['y'] == 0:
                continue # if the label is at 0,0  don't add
            label = dict(frame_key,
                         label_name = dlcname,
                         x = dlcres[scorer][dlcname].iloc[iframe]['x'],
                         y = dlcres[scorer][dlcname].iloc[iframe]['y'])
            labels_to_insert.append(label)
    (PoseEstimationLabelSet.Label() & labels_to_delete).delete(force = True) # ask the user to confirm
    PoseEstimationLabelSet.Label.insert(labels_to_insert)
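update_labeling matches rows of the DLC file back to Frame entries by parsing the exported image names. A minimal sketch of that parsing, assuming the names follow the im_{i}_session{name}_frame{f}.png pattern produced by export_labeling (the helper name is illustrative):

```python
def parse_frame_name(frame_name):
    '''Recover a frame key from an exported image name, e.g.
    im_000003_sessionJC027_frame000123.png -> frame_num (and session_name).'''
    key = dict(frame_num=int(frame_name.split('frame')[-1].strip('.png')))
    if 'session' in frame_name:
        # the session name disambiguates equal frame numbers from different videos
        key['session_name'] = frame_name.split('session')[-1].split('_frame')[0]
    return key
```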

PoseEstimationModel

Bases: Manual

Source code in labdata/schema/video.py
@analysisschema
class PoseEstimationModel(dj.Manual):
    definition = '''
    model_num                : int
    ---
    algorithm_name           : varchar(24)    # Algorithm for pose estimation
    -> [nullable] AnalysisFile                # zipped model; no videos.
    -> [nullable] PoseEstimationLabelSet 
    parameters_dict = NULL   : varchar(2000)  # parameters json formatted dictionary
    training_datetime = NULL : datetime
    container_name = NULL    : varchar(64)    # Name of the container to use
    code_link = NULL         : varchar(300)   # link to the github of the algorithm
    '''

    def insert_model(self, model_num, 
                     model_folder=None,
                     pose_label_set_num = None,
                     algorithm_name = None,
                     parameters = None,
                     training_datetime=None,
                     container_name = None,
                     code_link = None):    
        import shutil
        if training_datetime is None:
            today = datetime.now()
        else:
            today = training_datetime
        dataset_name = datetime.strftime(today,'%Y%m%d_%H%M%S')

        # check if this model_number exists for another pose_label_set_num
        allmodels = pd.DataFrame(PoseEstimationModel.fetch())
        sel = allmodels[(allmodels.model_num.values == model_num) & (allmodels.pose_label_set_num.values != pose_label_set_num)]
        if len(sel):
            model_num = np.max(allmodels.model_num.values)+1

        if model_folder is None:
            model_folder = ((Path(prefs['local_paths'][0])/'pose_estimation_models')/f'{dataset_name}')/f'model_{model_num}'
        filepath = ((Path(prefs['local_paths'][0])/'pose_estimation_models')/f'{dataset_name}')/f'model_{model_num}'
        print(f'Creating archive {filepath}')
        shutil.make_archive(filepath, 'zip', model_folder)
        filepath = filepath.with_suffix('.zip')

        key = AnalysisFile().upload_files([filepath],dataset = dict(subject_name = 'pose_estimation_models',
                                                                  session_name = f'model_{model_num}',
                                                                  dataset_name = dataset_name))
        key = dict((AnalysisFile & key).proj().fetch1(),
                          model_num = model_num,
                          algorithm_name = algorithm_name,
                          training_datetime = today,
                          pose_label_set_num = pose_label_set_num,
                          parameters_dict = json.dumps(parameters) if not parameters is None else None,
                          container_name = container_name,
                          code_link = code_link)
        if not len(PoseEstimationModel & f'model_num = {model_num}'):
            self.insert1(key)
        else:
            print(f'Model {model_num} already exists. Updating but keeping last version in AWS.')        
            oldentry = (PoseEstimationModel & f'model_num = {model_num}').fetch1()
            for k in key.keys():
                if key[k] is None:
                    key[k] = oldentry[k]
            self.update1(key)
        # need to add a model evaluation part table here
    def get_model(self):
        filepath = (AnalysisFile & self).get()

        filepath = filepath[0]
        if not (filepath.parent/'config.yaml').exists():
            import shutil 
            shutil.unpack_archive(filepath,extract_dir = filepath.parent)
        return filepath.parent/'config.yaml' # return the path to the config file.
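insert_model guards against reusing a model_num that is already registered under a different pose_label_set_num by bumping to the next free number. That check can be sketched on a toy table (the function name is illustrative, not part of the schema):

```python
import numpy as np
import pandas as pd

def resolve_model_num(allmodels, model_num, pose_label_set_num):
    # If model_num is already used by a different label set, take the next free number
    sel = allmodels[(allmodels.model_num.values == model_num) &
                    (allmodels.pose_label_set_num.values != pose_label_set_num)]
    if len(sel):
        model_num = int(np.max(allmodels.model_num.values) + 1)
    return model_num

models = pd.DataFrame(dict(model_num=[1, 2], pose_label_set_num=[3, 4]))
```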

PoseEstimation

Bases: Manual

Source code in labdata/schema/video.py
@analysisschema
class PoseEstimation(dj.Manual):
    definition = '''
    -> PoseEstimationModel
    -> DatasetVideo
    label_name : varchar(54)
    ---
    x                 : longblob
    y                 : longblob
    z = NULL          : longblob
    likelihood = NULL : longblob
    '''
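Traces fetched from this table carry a likelihood per sample (when the model provides one). A common post-processing step is to blank low-confidence points before analysis; a minimal sketch, assuming a 0.9 threshold (the helper name and threshold are illustrative, not part of the schema):

```python
import numpy as np

def mask_low_likelihood(x, y, likelihood, threshold=0.9):
    # Blank low-confidence samples so downstream plots and statistics skip them
    x = np.asarray(x, dtype=float).copy()
    y = np.asarray(y, dtype=float).copy()
    bad = np.asarray(likelihood) < threshold
    x[bad] = np.nan
    y[bad] = np.nan
    return x, y
```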

EphysRecording

Bases: Imported

Source code in labdata/schema/ephys.py
@dataschema
class EphysRecording(dj.Imported):
    definition = '''
    -> Dataset
    ---
    n_probes               : smallint            # number of probes
    recording_duration     : float               # duration of the recording
    recording_software     : varchar(56)         # software_version 
    '''

    class ProbeSetting(dj.Part):
        definition = '''
        -> master
        probe_num               : smallint          # probe number
        ---
        -> ProbeConfiguration
        sampling_rate           : decimal(22,14)    # sampling rate 
        '''
    class ProbeFile(dj.Part):
        definition = '''
        -> EphysRecording.ProbeSetting
        -> File                                     # binary file that contains the data
        '''

    def add_spikeglx_recording(self,key):
        '''
        Adds a recording from Dataset ap.meta files.
        '''
        allpaths = pd.DataFrame((Dataset.DataFiles() & key).fetch()).file_path.values
        paths = natsorted(list(filter( lambda x: x.endswith('.ap.meta'),
                                       allpaths)))
        keys = []
        local_path = Path(prefs['local_paths'][0])
        for iprobe, p in enumerate(paths):
            # add each configuration
            tmp = ProbeConfiguration().add_from_spikeglx_metadata(local_path/p)
            tt = dict(key,n_probes = len(paths),probe_num = iprobe,**tmp)
            EphysRecording.insert1(tt,
                                   ignore_extra_fields = True,
                                   skip_duplicates = True,
                                   allow_direct_insert = True)
            EphysRecording.ProbeSetting.insert1(tt,
                                                ignore_extra_fields = True,
                                                skip_duplicates = True,
                                                allow_direct_insert = True)
            # only working for spikeglx files for the moment.
            pfiles = list(filter(lambda x: f'imec{iprobe}.ap.' in x,allpaths))
            EphysRecording.ProbeFile().insert([
                dict(tt,
                     **(File() & f'file_path = "{fi}"').proj().fetch(as_dict = True)[0])
                for fi in pfiles],
                                              skip_duplicates = True,
                                              ignore_extra_fields = True,
                                              allow_direct_insert = True)
            EphysRecordingNoiseStats().populate(tt) # try to populate the NoiseStats table (this will take a couple of minutes)

    def add_nidq_events(self,key = None):
        if key is None:
            key = [k for k in self] # create a list
        if type(key) is dict:
            key = [key]
        for k in key:
            dkey = (Dataset() & k).proj().fetch1()
            kk = [dict(dkey,
                       stream_name = 'nidq'),
                  dict(dkey,
                       stream_name = 'obx')]
            if len(DatasetEvents() & kk):
                print(f' DatasetEvents for nidq are already there for {dkey}')
                continue

            allpaths = pd.DataFrame((Dataset.DataFiles() & k).fetch()).file_path.values
            paths = list(filter( lambda x: '.nidq.' in x, allpaths))
            stream_name = 'nidq'
            if not len(paths):
                # try obx files
                paths = list(filter( lambda x: '.obx.' in x, allpaths))
                stream_name = 'obx'
            file_paths = (File() & [dict(file_path = p) for p in paths])
            try:
                file_paths = file_paths.get() # download or get the path
            except ValueError:
                raise(ValueError(f'Error getting files for dataset {k}.'))
            from ..rules.ephys import extract_events_from_nidq
            events,daq = extract_events_from_nidq(file_paths)
            dkey = dict(dkey,
                        stream_name = stream_name)
            if not len(events):
                print(f'No events for key: {k}')
                continue

            DatasetEvents().insert1(dkey, allow_direct_insert =  True)
            DatasetEvents.Digital().insert([dict(dkey,**ev) for ev in events], allow_direct_insert = True)
        return 
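add_spikeglx_recording pairs probes with files purely by SpikeGLX naming conventions: one .ap.meta file per probe, with imec{N}.ap.* files belonging to probe N. A dependency-free sketch of that filtering (the real code uses natsorted; plain sorted is used here for illustration):

```python
def split_spikeglx_paths(allpaths):
    # One .ap.meta per probe; imec{N}.ap.* files belong to probe N
    meta = sorted(p for p in allpaths if p.endswith('.ap.meta'))
    probes = {i: [p for p in allpaths if f'imec{i}.ap.' in p]
              for i, _ in enumerate(meta)}
    return meta, probes

paths = ['run_g0_t0.imec0.ap.meta', 'run_g0_t0.imec0.ap.bin',
         'run_g0_t0.imec1.ap.meta', 'run_g0_t0.imec1.ap.bin',
         'run_g0_t0.nidq.bin']
meta, probes = split_spikeglx_paths(paths)
```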

add_spikeglx_recording(key)

Adds a recording from Dataset ap.meta files.

Source code in labdata/schema/ephys.py
def add_spikeglx_recording(self,key):
    '''
    Adds a recording from Dataset ap.meta files.
    '''
    allpaths = pd.DataFrame((Dataset.DataFiles() & key).fetch()).file_path.values
    paths = natsorted(list(filter( lambda x: x.endswith('.ap.meta'),
                                   allpaths)))
    keys = []
    local_path = Path(prefs['local_paths'][0])
    for iprobe, p in enumerate(paths):
        # add each configuration
        tmp = ProbeConfiguration().add_from_spikeglx_metadata(local_path/p)
        tt = dict(key,n_probes = len(paths),probe_num = iprobe,**tmp)
        EphysRecording.insert1(tt,
                               ignore_extra_fields = True,
                               skip_duplicates = True,
                               allow_direct_insert = True)
        EphysRecording.ProbeSetting.insert1(tt,
                                            ignore_extra_fields = True,
                                            skip_duplicates = True,
                                            allow_direct_insert = True)
        # only working for spikeglx files for the moment.
        pfiles = list(filter(lambda x: f'imec{iprobe}.ap.' in x,allpaths))
        EphysRecording.ProbeFile().insert([
            dict(tt,
                 **(File() & f'file_path = "{fi}"').proj().fetch(as_dict = True)[0])
            for fi in pfiles],
                                          skip_duplicates = True,
                                          ignore_extra_fields = True,
                                          allow_direct_insert = True)
        EphysRecordingNoiseStats().populate(tt) # try to populate the NoiseStats table (this will take a couple of minutes)

SpikeSorting

Bases: Manual

Source code in labdata/schema/ephys.py
@analysisschema
class SpikeSorting(dj.Manual):
    definition = '''
    -> EphysRecording.ProbeSetting
    -> SpikeSortingParams
    ---
    algorithm_version         = NULL        : varchar(56)    # version of the algorithm used
    sorting_datetime          = NULL        : datetime       # date of the spike sorting analysis
    n_pre_samples             = NULL        : smallint       # to compute the waveform time 
    n_sorted_units            = NULL        : int            # number of sorted units
    n_detected_spikes         = NULL        : int            # number of detected spikes
    sorting_channel_indices   = NULL        : longblob       # channel_map
    sorting_channel_coords    = NULL        : longblob       # channel_positions
    additional_params = NULL                : varchar(2000)  # additional json formatted parameters
    -> [nullable] AnalysisFile.proj(features_file='file_path',features_storage='storage')
    -> [nullable] AnalysisFile.proj(waveforms_file='file_path',waveforms_storage='storage')
    container_version         = NULL        : varchar(512)    # name and version of the container
    '''
    # For each sorting, create a "features.hdf5" file containing: (this file can be > 4Gb)
    #    - template features
    #    - cluster indices
    #    - whitening_matrix
    #    - templates
    # For each sorting, create a "waveforms.hdf5" file containing: (this file can be > 10Gb)
    #   - filtered waveform samples for each unit (1000 per unit)
    #   - indices of the extracted waveforms

    class Segment(dj.Part):
        definition = '''
        -> master
        segment_num               : int  # number of the segment
        ---
        offset_samples            : int         # offset where the traces comes from
        segment                   : longblob    # 2 second segment of data in the AP band
        '''

    class Unit(dj.Part):
        definition = '''
        -> master
        unit_id                  : int       # cluster id
        ---
        spike_times              : longblob  # in samples (uint64)
        spike_positions  = NULL  : longblob  # spike position in the electrode (float32)
        spike_amplitudes = NULL  : longblob  # spike template amplitudes (float32)
        '''
        def get_sampling_rates(self):
            sampling_rates = (self*EphysRecording.ProbeSetting()).fetch('sampling_rate')
            return [float(s) for s in sampling_rates] # cast to float if decimal

        def get_units_with_waveforms(self, return_seconds = True, interp_method = 'cubic-spline'):
            units = []
            sampling_rates = self.get_sampling_rates()
            # get the interpolation functions for all experiments if return_seconds = True
            if return_seconds:
                exps = (EphysRecording.ProbeSetting() & self.proj()).proj().fetch(as_dict = True)
                interpolations = dict()
                for e in exps:
                    k = '{subject_name}_{session_name}_{dataset_name}_{probe_num}'.format(**e)
                    try:
                        interpolations[k] = (StreamSync() & e).apply(None, method = interp_method)
                    except Exception:
                        interpolations[k] = None
            for p,u,r in zip(self.proj(),self,sampling_rates):
                w = (SpikeSorting.Waveforms() & p)
                if len(w):
                    w = w.fetch('waveform_median')[0]
                else:
                    w = None
                units.append(dict(u, waveform_median = w))
                if return_seconds:
                    k = '{subject_name}_{session_name}_{dataset_name}_{probe_num}'.format(**p)
                    if interpolations[k] is None:
                        units[-1]['spike_times'] = units[-1]['spike_times'].astype(np.float32)/np.float32(r)
                    else:
                        units[-1]['spike_times'] = interpolations[k](units[-1]['spike_times'].astype(np.float32))
            return units

        def get_spike_times(self, as_dict = True, return_seconds = True, extra_keys = [], warn = True, include_metrics = False,
                            interp_method = 'cubic-spline'):
            '''
spike_times = get_spike_times()

Gets spike times corrected if the sync is applied.
    defaults:
        as_dict = True
        return_seconds = True
        extra_keys = []
        warn = True.       # show a warning when using just the sampling rate
        include_metrics = False
            '''
            if return_seconds:
                exps = (EphysRecording.ProbeSetting() & self.proj()).proj().fetch(as_dict = True)
                interpolations = dict()
                for e in exps:
                    k = '{subject_name}_{session_name}_{dataset_name}_{probe_num}'.format(**e)
                    try:
                        probe_num = e['probe_num']
                        interpolations[k] = (StreamSync() & e & f'stream_name = "imec{probe_num}"').apply(None, warn = False,
                                                                                                          method = interp_method)
                    except AssertionError as err:
                        if warn:
                            import warnings
                            warnings.warn(f"Using the sampling rate for spike times", RuntimeWarning)
                        rate = (EphysRecording.ProbeSetting() & e).fetch1('sampling_rate')
                        interpolations[k] = lambda x: x/float(rate)
            keys = ['subject_name','session_name','dataset_name','probe_num','parameter_set_num','unit_id','spike_times'] + extra_keys
            if include_metrics: # add the metrics keys
                keys += [attr for attr in UnitMetrics.heading.attributes if not attr in keys]
            if include_metrics:
                units = (self*UnitMetrics).fetch(*keys,as_dict=True)
            else:
                units = self.fetch(*keys,as_dict=True)
            if return_seconds:
                for u in units:
                    k = '{subject_name}_{session_name}_{dataset_name}_{probe_num}'.format(**u)
                    u['spike_times'] = interpolations[k](u['spike_times'])
            if as_dict:
                return units
            return [u['spike_times'] for u in units]

    class Waveforms(dj.Part):
        definition = '''
        -> SpikeSorting.Unit
        ---
        waveform_median   :  longblob         # average waveform (gain corrected in microvolt - float32)
        '''

    def delete(
            self,
            transaction = True,
            safemode  = None,
            force_parts = False,
            keep_analysis = False):

        files = [f['waveforms_file'] for f in self]
        files += [f['features_file'] for f in self]
        super().delete(transaction = transaction,
                       safemode = safemode,
                       force_parts = force_parts)
        if keep_analysis:
            print(f'Kept {files}.')
            return
        if len(self) == 0:
            if len(files):
                (AnalysisFile() & [f'file_path = "{t}"' for t in files]).delete(force_parts=force_parts,
                                                                                safemode = safemode) 

    class IntermediateFiles(dj.Part):
        definition = '''
        -> master
        -> AnalysisFile
        '''

    class LinkedDatasets(dj.Part):
        definition = '''
        -> master
        -> Dataset.proj(linked_session_name='session_name',linked_dataset_session='dataset_name')
        '''

Unit

Bases: Part

Source code in labdata/schema/ephys.py
    class Unit(dj.Part):
        definition = '''
        -> master
        unit_id                  : int       # cluster id
        ---
        spike_times              : longblob  # in samples (uint64)
        spike_positions  = NULL  : longblob  # spike position in the electrode (float32)
        spike_amplitudes = NULL  : longblob  # spike template amplitudes (float32)
        '''
        def get_sampling_rates(self):
            sampling_rates = (self*EphysRecording.ProbeSetting()).fetch('sampling_rate')
            return [float(s) for s in sampling_rates] # cast to float if decimal

        def get_units_with_waveforms(self, return_seconds = True, interp_method = 'cubic-spline'):
            units = []
            sampling_rates = self.get_sampling_rates()
            # get the interpolation functions for all experiments if return_seconds = True
            if return_seconds:
                exps = (EphysRecording.ProbeSetting() & self.proj()).proj().fetch(as_dict = True)
                interpolations = dict()
                for e in exps:
                    k = '{subject_name}_{session_name}_{dataset_name}_{probe_num}'.format(**e)
                    try:
                        interpolations[k] = (StreamSync() & e).apply(None, method = interp_method)
                    except Exception:
                        interpolations[k] = None
            for p,u,r in zip(self.proj(),self,sampling_rates):
                w = (SpikeSorting.Waveforms() & p)
                if len(w):
                    w = w.fetch('waveform_median')[0]
                else:
                    w = None
                units.append(dict(u, waveform_median = w))
                if return_seconds:
                    k = '{subject_name}_{session_name}_{dataset_name}_{probe_num}'.format(**p)
                    if interpolations[k] is None:
                        units[-1]['spike_times'] = units[-1]['spike_times'].astype(np.float32)/np.float32(r)
                    else:
                        units[-1]['spike_times'] = interpolations[k](units[-1]['spike_times'].astype(np.float32))
            return units

        def get_spike_times(self, as_dict = True, return_seconds = True, extra_keys = [], warn = True, include_metrics = False,
                            interp_method = 'cubic-spline'):
            '''
spike_times = get_spike_times()

Gets spike times corrected if the sync is applied.
    defaults:
        as_dict = True
        return_seconds = True
        extra_keys = []
        warn = True.       # show a warning when using just the sampling rate
        include_metrics = False
            '''
            if return_seconds:
                exps = (EphysRecording.ProbeSetting() & self.proj()).proj().fetch(as_dict = True)
                interpolations = dict()
                for e in exps:
                    k = '{subject_name}_{session_name}_{dataset_name}_{probe_num}'.format(**e)
                    try:
                        probe_num = e['probe_num']
                        interpolations[k] = (StreamSync() & e & f'stream_name = "imec{probe_num}"').apply(None, warn = False,
                                                                                                          method = interp_method)
                    except AssertionError as err:
                        if warn:
                            import warnings
                            warnings.warn(f"Using the sampling rate for spike times", RuntimeWarning)
                        rate = (EphysRecording.ProbeSetting() & e).fetch1('sampling_rate')
                        interpolations[k] = lambda x: x/float(rate)
            keys = ['subject_name','session_name','dataset_name','probe_num','parameter_set_num','unit_id','spike_times'] + extra_keys
            if include_metrics: # add the metrics keys
                keys += [attr for attr in UnitMetrics.heading.attributes if not attr in keys]
            if include_metrics:
                units = (self*UnitMetrics).fetch(*keys,as_dict=True)
            else:
                units = self.fetch(*keys,as_dict=True)
            if return_seconds:
                for u in units:
                    k = '{subject_name}_{session_name}_{dataset_name}_{probe_num}'.format(**u)
                    u['spike_times'] = interpolations[k](u['spike_times'])
            if as_dict:
                return units
            return [u['spike_times'] for u in units]

get_spike_times(as_dict=True, return_seconds=True, extra_keys=[], warn=True, include_metrics=False, interp_method='cubic-spline')

spike_times = get_spike_times()

Gets spike times, corrected if the sync is applied. Defaults: as_dict = True, return_seconds = True, extra_keys = [], warn = True (warn when falling back to the sampling rate), include_metrics = False.

Source code in labdata/schema/ephys.py
        def get_spike_times(self, as_dict = True, return_seconds = True, extra_keys = [], warn = True, include_metrics = False,
                            interp_method = 'cubic-spline'):
            '''
spike_times = get_spike_times()

Gets spike times corrected if the sync is applied.
    defaults:
        as_dict = True
        return_seconds = True
        extra_keys = []
        warn = True.       # show a warning when using just the sampling rate
        include_metrics = False
            '''
            if return_seconds:
                exps = (EphysRecording.ProbeSetting() & self.proj()).proj().fetch(as_dict = True)
                interpolations = dict()
                for e in exps:
                    k = '{subject_name}_{session_name}_{dataset_name}_{probe_num}'.format(**e)
                    try:
                        probe_num = e['probe_num']
                        interpolations[k] = (StreamSync() & e & f'stream_name = "imec{probe_num}"').apply(None, warn = False,
                                                                                                          method = interp_method)
                    except AssertionError as err:
                        if warn:
                            import warnings
                            warnings.warn(f"Using the sampling rate for spike times", RuntimeWarning)
                        rate = (EphysRecording.ProbeSetting() & e).fetch1('sampling_rate')
                        interpolations[k] = lambda x: x/float(rate)
            keys = ['subject_name','session_name','dataset_name','probe_num','parameter_set_num','unit_id','spike_times'] + extra_keys
            if include_metrics: # add the metrics keys
                keys += [attr for attr in UnitMetrics.heading.attributes if not attr in keys]
            if include_metrics:
                units = (self*UnitMetrics).fetch(*keys,as_dict=True)
            else:
                units = self.fetch(*keys,as_dict=True)
            if return_seconds:
                for u in units:
                    k = '{subject_name}_{session_name}_{dataset_name}_{probe_num}'.format(**u)
                    u['spike_times'] = interpolations[k](u['spike_times'])
            if as_dict:
                return units
            return [u['spike_times'] for u in units]
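When no StreamSync entry is available, get_spike_times falls back to dividing the spike samples by the probe's sampling rate. A minimal sketch of that conversion (the helper name and rate are illustrative):

```python
import numpy as np

def samples_to_seconds(spike_samples, sampling_rate, interpolation=None):
    # Prefer the sync interpolation; otherwise divide samples by the sampling rate
    spike_samples = np.asarray(spike_samples, dtype=np.float64)
    if interpolation is not None:
        return interpolation(spike_samples)
    return spike_samples / float(sampling_rate)
```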

SpikeSortingParams

Bases: Manual

Source code in labdata/schema/ephys.py
@analysisschema
class SpikeSortingParams(dj.Manual):
    definition = '''
    parameter_set_num              : int            # number of the parameters set
    ---
    algorithm_name                 : varchar(64)    # preprocessing  and spike sorting algorithm 
    parameter_description = NULL   : varchar(256)   # description or specific use case
    parameters_dict                : varchar(2000)  # parameters json formatted dictionary
    code_link = NULL               : varchar(300)   # the software that preprocesses and sorts
    '''
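parameters_dict stores the sorting parameters as a JSON-encoded string in a varchar(2000) column. A minimal sketch of round-tripping a parameter set (the key names are illustrative, not a fixed schema):

```python
import json

# Illustrative parameter names; the actual keys depend on the sorting algorithm
params = {'algorithm': 'kilosort4', 'highpass_hz': 300, 'drift_correction': True}
parameters_dict = json.dumps(params)
assert len(parameters_dict) <= 2000  # must fit the varchar(2000) column
```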

UnitMetrics

Bases: Computed

Source code in labdata/schema/ephys.py
@analysisschema 
class UnitMetrics(dj.Computed):
   default_container = 'labdata-spks'
   # Compute the metrics from the each unit,
   # so we can recompute and add new ones if needed and not depend on the clustering
   definition = '''
   -> SpikeSorting.Unit
   ---
   num_spikes                         : int
   depth                    = NULL    : double
   position                 = NULL    : blob
   shank                    = NULL    : int
   channel_index            = NULL    : int
   n_electrodes_spanned     = NULL    : int
   firing_rate              = NULL    : float
   isi_contamination        = NULL    : float
   isi_contamination_hill   = NULL    : float
   amplitude_cutoff         = NULL    : float
   presence_ratio           = NULL    : float
   depth_drift_range        = NULL    : float
   depth_drift_fluctuation  = NULL    : float
   depth_drift_start_to_end = NULL    : float
   spike_amplitude          = NULL    : float
   spike_duration           = NULL    : float
   trough_time              = NULL    : float
   trough_amplitude         = NULL    : float
   fw3m                     = NULL    : float
   trough_gradient          = NULL    : float
   peak_gradient            = NULL    : float
   peak_time                = NULL    : float
   peak_amplitude           = NULL    : float
   polarity                 = NULL    : tinyint
   active_electrodes        = NULL    : blob
   '''
   def make(self, key):
       dat = (SpikeSorting.Unit & key).get_units_with_waveforms()
       assert len(dat) == 1, 'Need to select only one unit'
       dat = dat[0]

       from spks.metrics import (isi_contamination,
                                 isi_contamination_hill,
                                 amplitude_cutoff,
                                 presence_ratio,
                                 firing_rate,
                                 depth_stability)
       from spks.waveforms import waveforms_position, compute_waveform_metrics

       kk = {k:dat[k] for k in ['subject_name','session_name','dataset_name','probe_num']}
       channel_coords,srate,wpre = (EphysRecording.ProbeSetting()*SpikeSorting() & kk 
                                    & f'parameter_set_num = {key["parameter_set_num"]}').fetch1(
                                    'sorting_channel_coords','sampling_rate','n_pre_samples')
       channel_shanks,duration = (EphysRecording()*EphysRecording.ProbeSetting()*
                                 ProbeConfiguration() & kk).fetch1('channel_shank','recording_duration')

       metrics = dict(key)

       metrics['num_spikes'] = len(dat['spike_times'])
       if metrics['num_spikes'] > 5: # skip if less than 5 spikes
           metrics['firing_rate'] = firing_rate(dat['spike_times'],0, t_max = duration)
       if metrics['num_spikes'] > 50: # skip if less than 50 spikes
           metrics['isi_contamination'] = isi_contamination(dat['spike_times'], T = duration)
           metrics['isi_contamination_hill'] = isi_contamination_hill(dat['spike_times'],
                                                                      T = duration)
           metrics['amplitude_cutoff'] = amplitude_cutoff(dat['spike_amplitudes'])
           metrics['presence_ratio'] = presence_ratio(dat['spike_times'],
                                                      t_min = 0, t_max = duration)
           metrics['depth_drift_range'],metrics['depth_drift_fluctuation'],metrics['depth_drift_start_to_end'] = depth_stability(dat['spike_times'], dat['spike_positions'][:,1], tmax = duration)
           if not dat['waveform_median'] is None:
               waves = dat['waveform_median']

               pos,channel,active_idx = waveforms_position(np.expand_dims(dat['waveform_median'],
                                                                          axis = 0),
                                                           channel_positions = channel_coords,
                                                           active_electrode_threshold = 3)
               if not np.all(np.isfinite(pos)): 
                   # this can happen when there is noise in the waveform estimate, choose the position of the peak channel then
                   pos = channel_coords[channel]
                   active_idx = [channel]
               if len(active_idx):
                   metrics['n_electrodes_spanned'] = len(active_idx[0])
                   metrics['active_electrodes'] = np.array(active_idx).astype(int)
                   metrics['depth'] = pos[0][1] # TODO: estimate position from 6 channels around the peak channel
                   metrics['position'] = pos[0]
                   metrics['shank'] = channel_shanks[channel[0]]
                   metrics['channel_index'] = channel[0]
                   # the com only works if it is done only for the values that have spikes
                   wavemetrics = compute_waveform_metrics(waves[:,channel[0]],
                                                          wpre, float(srate))
                   metrics = dict(metrics,**wavemetrics)
                   metrics['spike_amplitude'] = np.abs(metrics['trough_amplitude']-metrics['peak_amplitude']) 
       self.insert1(metrics,skip_duplicates = True)
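The quality metrics above come from the `spks` package; as a rough sketch under standard definitions (not the `spks` implementations), the firing rate and presence ratio can be computed as:

```python
import numpy as np

def firing_rate(spike_times, t_min, t_max):
    # mean spikes per second over the recording window
    return len(spike_times) / (t_max - t_min)

def presence_ratio(spike_times, t_min, t_max, n_bins=100):
    # fraction of time bins that contain at least one spike
    hist, _ = np.histogram(spike_times, bins=n_bins, range=(t_min, t_max))
    return float(np.mean(hist > 0))

# 500 synthetic spike times over a 100 s recording
spikes = np.sort(np.random.default_rng(0).uniform(0, 100, size=500))
assert firing_rate(spikes, 0, 100) == 5.0
assert 0.0 <= presence_ratio(spikes, 0, 100) <= 1.0
```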

UnitCount

Bases: Computed

Source code in labdata/schema/ephys.py
@analysisschema 
class UnitCount(dj.Computed):
    definition = '''
    -> SpikeSorting
    -> UnitCountCriteria
    ---
    all : int
    sua : int
    mua : int
    '''
    class Unit(dj.Part):
        definition = '''
        -> master
        -> UnitMetrics
        ---
        passes  : tinyint
        '''

    def make(self,key):
        allu = pd.DataFrame((UnitMetrics() & key).fetch())
        criteria = (UnitCountCriteria() &
                    f"unit_criteria_id = {key['unit_criteria_id']}").fetch('sua_criteria')[0]
        suaidx = _apply_unit_criteria(allu, criteria)
        sua = np.sum(suaidx)
        muacriteria = (UnitCountCriteria() &
                       f"unit_criteria_id = {key['unit_criteria_id']}").fetch('mua_criteria')[0]
        if muacriteria is None:
            mua = len(allu) 
        else:
            mua = np.sum(_apply_unit_criteria(allu, muacriteria))
        mua -= sua
        unitcounts = dict(key,
                          all = len(allu),
                          sua = sua,
                          mua = mua)        
        self.insert1(unitcounts)
        # select only a projection later, for now we ignore the extra fields
        allu['passes'] = suaidx.astype(int)
        keys = [dict(a,unit_criteria_id = key['unit_criteria_id']) for i,a in allu.iterrows()]
        self.Unit.insert(keys, ignore_extra_fields = True) # add to the Unit part table
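`_apply_unit_criteria` splits the metrics table into single- and multi-unit counts. A toy sketch of that bookkeeping, using a pandas query string as a hypothetical stand-in for the criteria format:

```python
import numpy as np
import pandas as pd

# Toy metrics table; column names mirror UnitMetrics attributes.
units = pd.DataFrame({
    "unit_id": [0, 1, 2, 3],
    "firing_rate": [0.05, 2.0, 8.0, 0.5],
    "isi_contamination": [0.4, 0.05, 0.02, 0.3],
})

# Hypothetical stand-in for _apply_unit_criteria: evaluate a boolean expression.
sua_criteria = "firing_rate > 0.1 and isi_contamination < 0.1"
suaidx = units.eval(sua_criteria)

sua = int(np.sum(suaidx))   # units passing the single-unit criteria
mua = len(units) - sua      # remainder counted as multi-unit activity
assert sua == 2 and mua == 2
```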

Widefield

Bases: Imported

Table for widefield one-photon imaging data.

This table stores metadata about widefield imaging recordings including:

  • Frame dimensions and counts
  • Frame rate
  • Optical parameters (magnification, objective, pixel scale)
  • Reference to raw data file
  • Imaging software details

The table includes a Part table for storing different projections of the data (mean, std, var, max).

Source code in labdata/schema/onephoton.py
@dataschema
class Widefield(dj.Imported):
    '''Table for widefield one-photon imaging data.

    This table stores metadata about widefield imaging recordings including:
    - Frame dimensions and counts
    - Frame rate
    - Optical parameters (magnification, objective, pixel scale)
    - Reference to raw data file
    - Imaging software details

    The table includes a Part table for storing different projections of the data
    (mean, std, var, max).
    '''
    definition = '''
    -> Dataset
    ---
    n_channels             : smallint            # number of channels
    n_frames               : int                 # duration of the recording
    width                  : int                 # width of each frame
    height                 : int                 # height of each frame
    frame_rate             : double              # frame rate
    -> File                                      # path to the stack
    magnification = NULL   : double              # magnification
    objective_angle = NULL : double              # angle
    objective = NULL       : varchar(32)         # objective
    um_per_pixel = NULL    : blob                # XY scale conversion factors
    imaging_software       : varchar(32)         # software and version
    '''

    def open(self):
        '''Opens the widefield imaging data file.

        Returns
        -------
        zarr.Array
            The opened zarr array containing the widefield imaging data.
            Data is stored in a compressed zarr format with dimensions:
            [frames, channels, height, width]
        '''
        if len(self) != 1:
            raise ValueError(f'Select only one dataset {self.proj().fetch(as_dict = True)}.')
        fname = (File() & (Widefield & self.proj()) & 'file_path LIKE "%.zarr.zip"').get()[0]
        return open_zarr(fname)

    class Projection(dj.Part):
        '''Part table for storing projections of widefield imaging data.

        This table stores projections (mean, std, var, max) of the 
        widefield imaging data.

        Attributes
        ----------
        proj_name : enum
            Type of projection ('mean', 'std', 'var', 'max')
        proj : longblob
            The projection data array
        '''
        definition = '''
        -> master
        proj_name           : enum('mean','std','var','max')                  
        ---
        proj      : longblob
        '''

Projection

Bases: Part

Part table for storing projections of widefield imaging data.

This table stores projections (mean, std, var, max) of the widefield imaging data.

Attributes:
  • proj_name (enum) –

    Type of projection ('mean', 'std', 'var', 'max')

  • proj (longblob) –

    The projection data array

Source code in labdata/schema/onephoton.py
class Projection(dj.Part):
    '''Part table for storing projections of widefield imaging data.

    This table stores projections (mean, std, var, max) of the 
    widefield imaging data.

    Attributes
    ----------
    proj_name : enum
        Type of projection ('mean', 'std', 'var', 'max')
    proj : longblob
        The projection data array
    '''
    definition = '''
    -> master
    proj_name           : enum('mean','std','var','max')                  
    ---
    proj      : longblob
    '''

open()

Opens the widefield imaging data file.

Returns:
  • Array

    The opened zarr array containing the widefield imaging data. Data is stored in a compressed zarr format with dimensions: [frames, channels, height, width]

Source code in labdata/schema/onephoton.py
def open(self):
    '''Opens the widefield imaging data file.

    Returns
    -------
    zarr.Array
        The opened zarr array containing the widefield imaging data.
        Data is stored in a compressed zarr format with dimensions:
        [frames, channels, height, width]
    '''
    if len(self) != 1:
        raise ValueError(f'Select only one dataset {self.proj().fetch(as_dict = True)}.')
    fname = (File() & (Widefield & self.proj()) & 'file_path LIKE "%.zarr.zip"').get()[0]
    return open_zarr(fname)
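`open()` returns a zarr array with dimensions `[frames, channels, height, width]`; typical indexing patterns look like this (using a numpy array as a stand-in for the zarr-backed stack, with illustrative dimensions):

```python
import numpy as np

# Stand-in for the array returned by Widefield.open():
# dimensions are [frames, channels, height, width].
stack = np.zeros((1000, 2, 540, 640), dtype=np.uint16)

channel0 = stack[:, 0]            # all frames of the first channel
first_frame = stack[0]            # one frame: [channels, height, width]
mean_image = stack[:, 0].mean(axis=0)
assert mean_image.shape == (540, 640)
```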

FixedBrain

Bases: Imported

Whole brain histology or fixed tissue. The class provides methods for:

  • Loading brain image data via the get() method
  • Viewing brain data in napari via the napari_open() method

Definition

  • file_path (str) – Path to the brain imaging data file
  • num_channels (int) – Number of imaging channels
  • width (int) – Image width in pixels
  • height (int) – Image height in pixels
  • um_per_pixel (array) – Microns per pixel resolution in each dimension
  • hardware (str) – Imaging hardware/microscope used

Source code in labdata/schema/histology.py
@dataschema
class FixedBrain(dj.Imported):
    '''Whole brain histology or fixed tissue. 
    The class provides methods for:
    - Loading brain image data via get() method
    - Viewing brain data in napari via napari_open() method

    Definition
    ----------
    file_path : str
        Path to the brain imaging data file
    num_channels : int
        Number of imaging channels
    width : int 
        Image width in pixels
    height : int
        Image height in pixels
    um_per_pixel : array
        Microns per pixel resolution in each dimension
    hardware : str
        Imaging hardware/microscope used
    '''
    definition = '''
    -> Dataset
    ---
    -> [nullable] File
    num_channels = NULL          : smallint
    width = NULL                 : int
    height = NULL                : int
    um_per_pixel = NULL          : blob
    hardware  = NULL             : varchar(56)
    '''

    class Channel(dj.Part):
        definition = '''
        -> master
        channel_index : smallint
        ---
        channel_wavelength = NULL : float
        channel_description = NULL : varchar(64)
        '''

    def get(self):
        '''Get the brain imaging data.

        Returns
        -------
        array or list
            If single brain, returns array containing brain data.
            If multiple brains, returns list of arrays.
        '''
        brains = []
        for s in self:
            brains.append(open_zarr((File() & s).get()[0]))
        if len(brains)==1:
            return brains[0] # like fetch1
        return brains

    def napari_open(self, **kwargs):
        '''Open brain data in napari viewer.

        Opens the brain imaging data in a napari viewer window for visualization.
        Only one brain can be opened at a time.

        Pass channel_axis = 1 to open with color.

        Returns
        -------
        napari.Viewer
            The napari viewer instance displaying the brain data

        Raises
        ------
        AssertionError
            If more than one brain is selected
        '''
        assert len(self) == 1, 'Open only one brain at a time.'
        from labdata.stacks import napari_open
        napari_open(self.get(),**kwargs)

get()

Get the brain imaging data.

Returns:
  • array or list

    If single brain, returns array containing brain data. If multiple brains, returns list of arrays.

Source code in labdata/schema/histology.py
def get(self):
    '''Get the brain imaging data.

    Returns
    -------
    array or list
        If single brain, returns array containing brain data.
        If multiple brains, returns list of arrays.
    '''
    brains = []
    for s in self:
        brains.append(open_zarr((File() & s).get()[0]))
    if len(brains)==1:
        return brains[0] # like fetch1
    return brains

napari_open(**kwargs)

Open brain data in napari viewer.

Opens the brain imaging data in a napari viewer window for visualization. Only one brain can be opened at a time.

Pass channel_axis = 1 to open with color

Returns:
  • Viewer

    The napari viewer instance displaying the brain data

Raises:
  • AssertionError

    If more than one brain is selected

Source code in labdata/schema/histology.py
def napari_open(self, **kwargs):
    '''Open brain data in napari viewer.

    Opens the brain imaging data in a napari viewer window for visualization.
    Only one brain can be opened at a time.

    Pass channel_axis = 1 to open with color.

    Returns
    -------
    napari.Viewer
        The napari viewer instance displaying the brain data

    Raises
    ------
    AssertionError
        If more than one brain is selected
    '''
    assert len(self) == 1, 'Open only one brain at a time.'
    from labdata.stacks import napari_open
    napari_open(self.get(),**kwargs)

FixedBrainTransform

Bases: Computed

Table for storing transformed fixed brain images.

This class computes and stores transformed versions of fixed brain images based on parameters from FixedBrainTransformParameters.

The transformed images are stored as TIFF files in the analysis storage location.

Definition

  • file_path (str) – Path to the transformed TIFF file in analysis storage
  • storage (str) – Storage location name (default: 'analysis')
  • um_per_pixel (array-like) – Resolution in microns per pixel for each dimension
  • shape (array-like) – Shape of the transformed stack [T,C,X,Y]
  • hemisphere (str) – Which hemisphere is included ('left', 'right', or 'both')

Methods:

  • transform(key) – Apply the transformations specified in the parameters to generate the transformed stack. Returns the transformed image stack as an ndarray.

  • get() – Load and return the transformed image stack(s). Returns a list of transformed image stacks as numpy arrays.

Source code in labdata/schema/histology.py
@analysisschema
class FixedBrainTransform(dj.Computed):
    '''Table for storing transformed fixed brain images.

    This class computes and stores transformed versions of fixed brain images based on 
    parameters from FixedBrainTransformParameters. 

    The transformed images are stored as TIFF files in the analysis storage location.

    Definition
    ----------
    file_path : str
        Path to the transformed TIFF file in analysis storage
    storage : str
        Storage location name (default: 'analysis')
    um_per_pixel : array-like
        Resolution in microns per pixel for each dimension
    shape : array-like 
        Shape of the transformed stack [T,C,X,Y]
    hemisphere : str
        Which hemisphere is included ('left', 'right', or 'both')

    Methods
    -------
    transform(key)
        Apply transformations specified in parameters to generate transformed stack

        Returns
        -------
        ndarray
            The transformed image stack

    get()
        Load and return the transformed image stack(s)

        Returns
        -------
        list
            List of transformed image stacks as numpy arrays
    '''
    definition = '''
    -> FixedBrain
    -> FixedBrainTransformParameters
    ---
    -> AnalysisFile
    um_per_pixel = NULL : blob
    shape = NULL        : blob
    hemisphere = NULL   : varchar(5)  # left, right, both
    '''

    def transform(self,key):
        '''Transform fixed brain image according to parameters.

        Applies transformations specified in FixedBrainTransformParameters to generate
        a transformed stack. If the transform has already been computed, loads and 
        returns the existing transformed stack from storage.

        Transformation order:
        1. Downsample if specified
        2. Rotate if specified 
        3. Crop if specified
        4. Transpose dimensions if specified

        Parameters
        ----------
        key : dict
            Primary key specifying which FixedBrainTransformParameters to use

        Returns
        -------
        ndarray
            The transformed image stack. Shape depends on transformation parameters.

        '''
        if len(self & key):
            key = (self & key).fetch1()
            print(f'Transform has been computed. Fetching from {key["file_path"]}.')
            apath = (AnalysisFile() & key).get()[0]

            from tifffile import imread 
            stack = imread(apath)
            return stack
        stack = (FixedBrain() & key).get()
        params = (FixedBrainTransformParameters() & key).fetch1()
        from labdata.stacks import rotate_stack,downsample_stack
        if not params['downsample'] is None:
            stack = downsample_stack(stack,params['downsample'])
        if not params['rotate'] is None:
            stack = rotate_stack(stack,*params['rotate'])
        if not params['crop'] is None:
            A,B,C,D = params['crop']
            if A is None:
                A = [0,stack.shape[0],1]
            if B is None:
                B = [0,stack.shape[1],1]
            if C is None:
                C = [0,stack.shape[2],1]
            if D is None:
                D = [0,stack.shape[3],1]
            stack = stack[A[0]:A[1]:A[2],
                          B[0]:B[1]:B[2],
                          C[0]:C[1]:C[2],
                          D[0]:D[1]:D[2],]
        if not params['transpose'] is None:
            stack = stack.transpose(params['transpose'])
        return stack

    def get(self):
        '''Load the transformed brain stacks.

        Returns
        -------
        ndarray or list
            If only one stack exists, returns a single ndarray.
            If multiple stacks exist, returns a list of ndarrays.
        '''
        brains = []
        from tifffile import imread
        for s in self:
            brains.append(imread((AnalysisFile() & s).get()[0]))
        if len(brains)==1:
            return brains[0] # like fetch1
        return brains

    def make(self,k):
        par = (FixedBrainTransformParameters() & k).fetch1()
        origpar = (FixedBrain() & k).fetch1()

        downsample_par = np.array(par['downsample'])[np.array([0,2,3])]
        stack = self.transform(k)
        um_per_pixel = list(origpar['um_per_pixel']/downsample_par)

        folder_path = (((Path(prefs['local_paths'][0])
                            /k['subject_name']))
                            /k['session_name'])/f'brain_transform_{k["transform_id"]}'
        filepath = folder_path/f'stack_{um_per_pixel[0]}um.ome.tif'
        folder_path.mkdir(exist_ok=True)
        from tifffile import imwrite  # saving in tiff so it is easier to read
        imwrite(filepath,stack, 
                imagej = True,
                metadata={'axes': 'ZCYX'}, 
                compression ='zlib',
                compressionargs = {'level': 6})
        added = AnalysisFile().upload_files([filepath],dict(subject_name = k['subject_name'],
                                                session_name = k['session_name'],
                                                dataset_name = f'brain_transform_{k["transform_id"]}'))[0]

        to_add = dict(k,
                      um_per_pixel = um_per_pixel,
                      shape = stack.shape,
                      **added)
        self.insert1(to_add)

get()

Load the transformed brain stacks.

Returns:
  • ndarray or list

    If only one stack exists, returns a single ndarray. If multiple stacks exist, returns a list of ndarrays.

Source code in labdata/schema/histology.py
def get(self):
    '''Load the transformed brain stacks.

    Returns
    -------
    ndarray or list
        If only one stack exists, returns a single ndarray.
        If multiple stacks exist, returns a list of ndarrays.
    '''
    brains = []
    from tifffile import imread
    for s in self:
        brains.append(imread((AnalysisFile() & s).get()[0]))
    if len(brains)==1:
        return brains[0] # like fetch1
    return brains

transform(key)

Transform fixed brain image according to parameters.

Applies transformations specified in FixedBrainTransformParameters to generate a transformed stack. If the transform has already been computed, loads and returns the existing transformed stack from storage.

Transformation order:

  1. Downsample if specified
  2. Rotate if specified
  3. Crop if specified
  4. Transpose dimensions if specified

Parameters:
  • key (dict) –

    Primary key specifying which FixedBrainTransformParameters to use

Returns:
  • ndarray

    The transformed image stack. Shape depends on transformation parameters.

Source code in labdata/schema/histology.py
def transform(self,key):
    '''Transform fixed brain image according to parameters.

    Applies transformations specified in FixedBrainTransformParameters to generate
    a transformed stack. If the transform has already been computed, loads and 
    returns the existing transformed stack from storage.

    Transformation order:
    1. Downsample if specified
    2. Rotate if specified 
    3. Crop if specified
    4. Transpose dimensions if specified

    Parameters
    ----------
    key : dict
        Primary key specifying which FixedBrainTransformParameters to use

    Returns
    -------
    ndarray
        The transformed image stack. Shape depends on transformation parameters.

    '''
    if len(self & key):
        key = (self & key).fetch1()
        print(f'Transform has been computed. Fetching from {key["file_path"]}.')
        apath = (AnalysisFile() & key).get()[0]

        from tifffile import imread 
        stack = imread(apath)
        return stack
    stack = (FixedBrain() & key).get()
    params = (FixedBrainTransformParameters() & key).fetch1()
    from labdata.stacks import rotate_stack,downsample_stack
    if not params['downsample'] is None:
        stack = downsample_stack(stack,params['downsample'])
    if not params['rotate'] is None:
        stack = rotate_stack(stack,*params['rotate'])
    if not params['crop'] is None:
        A,B,C,D = params['crop']
        if A is None:
            A = [0,stack.shape[0],1]
        if B is None:
            B = [0,stack.shape[1],1]
        if C is None:
            C = [0,stack.shape[2],1]
        if D is None:
            D = [0,stack.shape[3],1]
        stack = stack[A[0]:A[1]:A[2],
                      B[0]:B[1]:B[2],
                      C[0]:C[1]:C[2],
                      D[0]:D[1]:D[2],]
    if not params['transpose'] is None:
        stack = stack.transpose(params['transpose'])
    return stack
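The crop step above expects one `[start, stop, step]` triple per axis, with `None` keeping the full axis. A self-contained sketch of that slicing logic on a toy `[Z, C, Y, X]` stack:

```python
import numpy as np

stack = np.arange(2 * 3 * 8 * 8).reshape(2, 3, 8, 8)  # toy [Z, C, Y, X] stack

# Crop parameters as stored in FixedBrainTransformParameters 'crop':
# a [start, stop, step] triple per axis; None keeps the full axis.
crop = [None, [0, 1, 1], [2, 6, 1], [2, 6, 2]]
slices = tuple(
    slice(*(c if c is not None else [0, n, 1]))
    for c, n in zip(crop, stack.shape)
)
cropped = stack[slices]
assert cropped.shape == (2, 1, 4, 2)
```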

FixedBrainTransformAnnotation

Bases: Manual

Table for storing manual annotations of brain locations.

This table stores manually annotated points in transformed brain volumes, such as:

  • Probe tracks
  • Injection sites
  • Anatomical landmarks
  • Region boundaries

Each annotation consists of:

  • annotation_name: Description of what is being annotated
  • annotation_type: Category of annotation (e.g. 'probe_track', 'injection')
  • xyz: Array of x,y,z coordinates marking the annotation location

The coordinates are in pixels relative to the transformed brain volume.

Source code in labdata/schema/histology.py
@analysisschema
class FixedBrainTransformAnnotation(dj.Manual):
    '''Table for storing manual annotations of brain locations.

    This table stores manually annotated points in transformed brain volumes, such as:
    - Probe tracks
    - Injection sites 
    - Anatomical landmarks
    - Region boundaries

    Each annotation consists of:
    - annotation_name: Description of what is being annotated
    - annotation_type: Category of annotation (e.g. 'probe_track', 'injection')
    - xyz: Array of x,y,z coordinates marking the annotation location

    The coordinates are in pixels relative to the transformed brain volume.
    '''
    definition = '''
    -> FixedBrainTransform
    annotation_id : int
    ---
    annotation_name : varchar(36)
    annotation_type : varchar(36)
    xyz : blob
    '''
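An annotation row pairs the primary key of an existing FixedBrainTransform with an xyz coordinate array. A sketch of preparing one (all key values are hypothetical, and inserting requires a configured DataJoint connection):

```python
import numpy as np

# Hypothetical primary-key values; they must match an existing FixedBrainTransform.
annotation = dict(
    subject_name="JC001",
    session_name="20240101",
    dataset_name="histology",
    transform_id=0,
    annotation_id=0,
    annotation_name="probe A track",
    annotation_type="probe_track",
    xyz=np.array([[120, 340, 56], [122, 338, 80]]),  # pixel coordinates in the transformed volume
)
assert annotation["xyz"].shape[1] == 3  # one x,y,z triple per point
# FixedBrainTransformAnnotation.insert1(annotation)  # needs a DB connection
```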